Litellm ruff linting enforcement (#5992)
* ci(config.yml): add a 'check_code_quality' step. Addresses https://github.com/BerriAI/litellm/issues/5991
* ci(config.yml): check why circle ci doesn't pick up this test
* ci(config.yml): fix to run 'check_code_quality' tests
* fix(__init__.py): fix unprotected import
* fix(__init__.py): don't remove unused imports
* build(ruff.toml): update ruff.toml to ignore unused imports
* fix: ruff + pyright - fix linting + type-checking errors
* fix: fix linting errors
* fix(lago.py): fix module init error
* fix: fix linting errors
* ci(config.yml): cd into correct dir for checks
* fix(proxy_server.py): fix linting error
* fix(utils.py): fix bare except causes ruff linting errors
* fix: ruff - fix remaining linting errors
* fix(clickhouse.py): use standard logging object
* fix(__init__.py): fix unprotected import
* fix: ruff - fix linting errors
* fix: fix linting errors
* ci(config.yml): cleanup code qa step (formatting handled in local_testing)
* fix(_health_endpoints.py): fix ruff linting errors
* ci(config.yml): just use ruff in check_code_quality pipeline for now
* build(custom_guardrail.py): include missing file
* style(embedding_handler.py): fix ruff check
parent 3fc4ae0d65
commit d57be47b0f
263 changed files with 1687 additions and 3320 deletions
@@ -299,6 +299,27 @@ jobs:
             ls
             python -m pytest -vv tests/local_testing/test_python_38.py
 
+  check_code_quality:
+    docker:
+      - image: cimg/python:3.11
+        auth:
+          username: ${DOCKERHUB_USERNAME}
+          password: ${DOCKERHUB_PASSWORD}
+    working_directory: ~/project/litellm
+
+    steps:
+      - checkout
+      - run:
+          name: Install Dependencies
+          command: |
+            python -m pip install --upgrade pip
+            pip install ruff
+            pip install pylint
+            pip install .
+      - run: python -c "from litellm import *" || (echo '🚨 import failed, this means you introduced unprotected imports! 🚨'; exit 1)
+      - run: ruff check ./litellm
+
+
   build_and_test:
     machine:
       image: ubuntu-2204:2023.10.1

@@ -806,6 +827,12 @@ workflows:
              only:
                - main
                - /litellm_.*/
+      - check_code_quality:
+          filters:
+            branches:
+              only:
+                - main
+                - /litellm_.*/
       - ui_endpoint_testing:
           filters:
             branches:

@@ -867,6 +894,7 @@ workflows:
           - installing_litellm_on_python
           - proxy_logging_guardrails_model_info_tests
           - proxy_pass_through_endpoint_tests
+          - check_code_quality
       filters:
         branches:
           only:
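Note on the new CI step: `python -c "from litellm import *"` fails the build whenever importing the package pulls in an optional dependency unconditionally (the "unprotected imports" the commit message mentions). A minimal sketch of the protection pattern, using `prometheus_client` purely as an illustrative optional dependency, not a claim about which packages litellm guards this way:

    try:
        from prometheus_client import Counter  # optional dependency
    except ImportError:  # package missing -> degrade gracefully instead of breaking import
        Counter = None  # type: ignore[assignment,misc]

    REQUEST_COUNT = (
        Counter("litellm_requests_total", "Total requests") if Counter is not None else None
    )


    def record_request() -> None:
        # No-op when the optional package is absent, so `from litellm import *`
        # still succeeds in a minimal environment.
        if REQUEST_COUNT is not None:
            REQUEST_COUNT.inc()


    record_request()

With the import wrapped like this, the CI import check passes in a bare environment and `ruff check ./litellm` still sees ordinary, analyzable code.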
@@ -630,18 +630,6 @@ general_settings:
   "database_url": "string",
   "database_connection_pool_limit": 0, # default 100
   "database_connection_timeout": 0, # default 60s
-  "database_type": "dynamo_db",
-  "database_args": {
-    "billing_mode": "PROVISIONED_THROUGHPUT",
-    "read_capacity_units": 0,
-    "write_capacity_units": 0,
-    "ssl_verify": true,
-    "region_name": "string",
-    "user_table_name": "LiteLLM_UserTable",
-    "key_table_name": "LiteLLM_VerificationToken",
-    "config_table_name": "LiteLLM_Config",
-    "spend_table_name": "LiteLLM_SpendLogs"
-  },
   "otel": true,
   "custom_auth": "string",
   "max_parallel_requests": 0, # the max parallel requests allowed per deployment
@@ -97,7 +97,7 @@ class GenericAPILogger:
         for key, value in payload.items():
             try:
                 payload[key] = str(value)
-            except:
+            except Exception:
                 # non blocking if it can't cast to a str
                 pass
 
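The `except:` to `except Exception:` edits here and in the hunks below address ruff's E722 (bare except). A bare `except:` also swallows `KeyboardInterrupt` and `SystemExit`; catching `Exception` keeps the non-blocking intent without hiding interpreter-level signals. A small self-contained sketch of the same fix:

    payload = {"status": object(), "latency_ms": 42}

    for key, value in payload.items():
        try:
            payload[key] = str(value)
        except Exception:  # a bare `except:` here would be flagged as ruff E722
            # non-blocking if the value can't be cast to a str
            pass

    print(payload)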
@@ -49,7 +49,7 @@ class _ENTERPRISE_GoogleTextModeration(CustomLogger):
     def __init__(self):
         try:
             from google.cloud import language_v1
-        except:
+        except Exception:
             raise Exception(
                 "Missing google.cloud package. Run `pip install --upgrade google-cloud-language`"
             )

@@ -90,7 +90,7 @@ class _ENTERPRISE_GoogleTextModeration(CustomLogger):
             verbose_proxy_logger.debug(print_statement)
             if litellm.set_verbose:
                 print(print_statement)  # noqa
-        except:
+        except Exception:
             pass
 
     async def async_moderation_hook(
@@ -58,7 +58,7 @@ class _ENTERPRISE_LlamaGuard(CustomLogger):
             verbose_proxy_logger.debug(print_statement)
             if litellm.set_verbose:
                 print(print_statement)  # noqa
-        except:
+        except Exception:
             pass
 
     def set_custom_prompt_template(self, messages: list):
@@ -49,7 +49,7 @@ class _ENTERPRISE_LLMGuard(CustomLogger):
             verbose_proxy_logger.debug(print_statement)
             if litellm.set_verbose:
                 print(print_statement)  # noqa
-        except:
+        except Exception:
             pass
 
     async def moderation_check(self, text: str):
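Several retained context lines carry a trailing `# noqa`. That marker tells ruff (and flake8-compatible tools) to skip findings on that single line, which is how the deliberate `print(...)` calls coexist with the new lint gate. An illustrative sketch, not litellm's actual helper:

    def print_verbose(print_statement: str, set_verbose: bool = False) -> None:
        try:
            if set_verbose:
                # the trailing `# noqa` silences lint findings on this one line only
                print(print_statement)  # noqa
        except Exception:
            pass


    print_verbose("hello from the lint-gated proxy", set_verbose=True)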
@@ -3,7 +3,8 @@ import warnings
 
 warnings.filterwarnings("ignore", message=".*conflict with protected namespace.*")
 ### INIT VARIABLES ###
-import threading, requests, os
+import threading
+import os
 from typing import Callable, List, Optional, Dict, Union, Any, Literal, get_args
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
 from litellm.caching import Cache

@@ -308,13 +309,13 @@ def get_model_cost_map(url: str):
             return content
 
     try:
-        with requests.get(
+        response = httpx.get(
             url, timeout=5
-        ) as response:  # set a 5 second timeout for the get request
+        )  # set a 5 second timeout for the get request
         response.raise_for_status()  # Raise an exception if the request is unsuccessful
         content = response.json()
         return content
-    except Exception as e:
+    except Exception:
         import importlib.resources
         import json
 

@@ -839,7 +840,7 @@ openai_image_generation_models = ["dall-e-2", "dall-e-3"]
 
 from .timeout import timeout
 from .cost_calculator import completion_cost
-from litellm.litellm_core_utils.litellm_logging import Logging
+from litellm.litellm_core_utils.litellm_logging import Logging, modify_integration
 from litellm.litellm_core_utils.get_llm_provider_logic import get_llm_provider
 from litellm.litellm_core_utils.core_helpers import remove_index_from_tool_calls
 from litellm.litellm_core_utils.token_counter import get_modified_max_tokens

@@ -848,7 +849,6 @@ from .utils import (
     exception_type,
     get_optional_params,
     get_response_string,
-    modify_integration,
     token_counter,
     create_pretrained_tokenizer,
     create_tokenizer,
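`get_model_cost_map` now calls `httpx.get` directly instead of using `requests.get` as a context manager; `httpx.get` returns a fully-read response, so no `with` block is needed, and the unused `as e` binding is dropped. A self-contained sketch of the same fetch-with-local-fallback shape (the function name and arguments below are illustrative, not litellm's API):

    import importlib.resources
    import json

    import httpx


    def fetch_json_with_fallback(url: str, package: str, resource: str) -> dict:
        """Fetch JSON over the network, falling back to a file bundled in `package`."""
        try:
            response = httpx.get(url, timeout=5)  # 5 second timeout for the GET request
            response.raise_for_status()  # raise if the request is unsuccessful
            return response.json()
        except Exception:
            # Offline or failed request: load the copy shipped inside the package.
            with importlib.resources.open_text(package, resource) as f:
                return json.load(f)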
@@ -98,5 +98,5 @@ def print_verbose(print_statement):
     try:
         if set_verbose:
             print(print_statement)  # noqa
-    except:
+    except Exception:
         pass
@@ -2,5 +2,5 @@ import importlib_metadata
 
 try:
     version = importlib_metadata.version("litellm")
-except:
+except Exception:
     pass
@@ -12,7 +12,6 @@ from openai.types.beta.assistant import Assistant
 from openai.types.beta.assistant_deleted import AssistantDeleted
 
 import litellm
-from litellm import client
 from litellm.llms.AzureOpenAI import assistants
 from litellm.types.router import GenericLiteLLMParams
 from litellm.utils import (

@@ -96,7 +95,7 @@ def get_assistants(
     if (
         timeout is not None
         and isinstance(timeout, httpx.Timeout)
-        and supports_httpx_timeout(custom_llm_provider) == False
+        and supports_httpx_timeout(custom_llm_provider) is False
     ):
         read_timeout = timeout.read or 600
         timeout = read_timeout  # default 10 min timeout

@@ -280,7 +279,7 @@ def create_assistants(
     if (
         timeout is not None
         and isinstance(timeout, httpx.Timeout)
-        and supports_httpx_timeout(custom_llm_provider) == False
+        and supports_httpx_timeout(custom_llm_provider) is False
     ):
         read_timeout = timeout.read or 600
         timeout = read_timeout  # default 10 min timeout

@@ -464,7 +463,7 @@ def delete_assistant(
     if (
         timeout is not None
         and isinstance(timeout, httpx.Timeout)
-        and supports_httpx_timeout(custom_llm_provider) == False
+        and supports_httpx_timeout(custom_llm_provider) is False
     ):
         read_timeout = timeout.read or 600
         timeout = read_timeout  # default 10 min timeout

@@ -649,7 +648,7 @@ def create_thread(
     if (
         timeout is not None
         and isinstance(timeout, httpx.Timeout)
-        and supports_httpx_timeout(custom_llm_provider) == False
+        and supports_httpx_timeout(custom_llm_provider) is False
     ):
         read_timeout = timeout.read or 600
         timeout = read_timeout  # default 10 min timeout

@@ -805,7 +804,7 @@ def get_thread(
     if (
         timeout is not None
         and isinstance(timeout, httpx.Timeout)
-        and supports_httpx_timeout(custom_llm_provider) == False
+        and supports_httpx_timeout(custom_llm_provider) is False
     ):
         read_timeout = timeout.read or 600
         timeout = read_timeout  # default 10 min timeout

@@ -991,7 +990,7 @@ def add_message(
     if (
         timeout is not None
         and isinstance(timeout, httpx.Timeout)
-        and supports_httpx_timeout(custom_llm_provider) == False
+        and supports_httpx_timeout(custom_llm_provider) is False
     ):
         read_timeout = timeout.read or 600
         timeout = read_timeout  # default 10 min timeout

@@ -1149,7 +1148,7 @@ def get_messages(
     if (
         timeout is not None
         and isinstance(timeout, httpx.Timeout)
-        and supports_httpx_timeout(custom_llm_provider) == False
+        and supports_httpx_timeout(custom_llm_provider) is False
     ):
         read_timeout = timeout.read or 600
         timeout = read_timeout  # default 10 min timeout

@@ -1347,7 +1346,7 @@ def run_thread(
     if (
         timeout is not None
         and isinstance(timeout, httpx.Timeout)
-        and supports_httpx_timeout(custom_llm_provider) == False
+        and supports_httpx_timeout(custom_llm_provider) is False
     ):
         read_timeout = timeout.read or 600
         timeout = read_timeout  # default 10 min timeout
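The repeated `== False` to `is False` rewrite in the assistants helpers above (and again in the files helpers further down) is ruff/pycodestyle E712: comparisons with `True`/`False` should use identity or plain truthiness, because `==` goes through `__eq__` and reads as a value comparison. A minimal sketch, where `supports_feature` is only a stand-in for `supports_httpx_timeout`:

    def supports_feature(provider: str) -> bool:
        # stand-in for supports_httpx_timeout(); always returns a real bool
        return provider == "openai"


    provider = "azure"

    # Flagged by ruff E712:  if supports_feature(provider) == False: ...
    # Preferred spelling (or simply `not supports_feature(provider)`):
    if supports_feature(provider) is False:
        print(f"{provider}: falling back to a plain read-timeout float")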
@@ -22,7 +22,7 @@ import litellm
 from litellm import client
 from litellm.llms.AzureOpenAI.azure import AzureBatchesAPI
 from litellm.llms.OpenAI.openai import OpenAIBatchesAPI
-from litellm.secret_managers.main import get_secret
+from litellm.secret_managers.main import get_secret, get_secret_str
 from litellm.types.llms.openai import (
     Batch,
     CancelBatchRequest,

@@ -131,7 +131,7 @@ def create_batch(
         extra_headers=extra_headers,
         extra_body=extra_body,
     )
+    api_base: Optional[str] = None
     if custom_llm_provider == "openai":
-
         # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there

@@ -165,27 +165,30 @@ def create_batch(
             _is_async=_is_async,
         )
     elif custom_llm_provider == "azure":
-        api_base = optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE")  # type: ignore
+        api_base = (
+            optional_params.api_base
+            or litellm.api_base
+            or get_secret_str("AZURE_API_BASE")
+        )
         api_version = (
             optional_params.api_version
             or litellm.api_version
-            or get_secret("AZURE_API_VERSION")
-        )  # type: ignore
+            or get_secret_str("AZURE_API_VERSION")
+        )
 
         api_key = (
             optional_params.api_key
             or litellm.api_key
             or litellm.azure_key
-            or get_secret("AZURE_OPENAI_API_KEY")
-            or get_secret("AZURE_API_KEY")
-        )  # type: ignore
+            or get_secret_str("AZURE_OPENAI_API_KEY")
+            or get_secret_str("AZURE_API_KEY")
+        )
 
         extra_body = optional_params.get("extra_body", {})
-        azure_ad_token: Optional[str] = None
         if extra_body is not None:
-            azure_ad_token = extra_body.pop("azure_ad_token", None)
+            extra_body.pop("azure_ad_token", None)
         else:
-            azure_ad_token = get_secret("AZURE_AD_TOKEN")  # type: ignore
+            get_secret_str("AZURE_AD_TOKEN")  # type: ignore
 
         response = azure_batches_instance.create_batch(
             _is_async=_is_async,

@@ -293,7 +296,7 @@ def retrieve_batch(
     )
 
     _is_async = kwargs.pop("aretrieve_batch", False) is True
+    api_base: Optional[str] = None
     if custom_llm_provider == "openai":
-
         # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there

@@ -327,27 +330,30 @@ def retrieve_batch(
             max_retries=optional_params.max_retries,
         )
     elif custom_llm_provider == "azure":
-        api_base = optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE")  # type: ignore
+        api_base = (
+            optional_params.api_base
+            or litellm.api_base
+            or get_secret_str("AZURE_API_BASE")
+        )
         api_version = (
             optional_params.api_version
             or litellm.api_version
-            or get_secret("AZURE_API_VERSION")
-        )  # type: ignore
+            or get_secret_str("AZURE_API_VERSION")
+        )
 
         api_key = (
             optional_params.api_key
             or litellm.api_key
             or litellm.azure_key
-            or get_secret("AZURE_OPENAI_API_KEY")
-            or get_secret("AZURE_API_KEY")
-        )  # type: ignore
+            or get_secret_str("AZURE_OPENAI_API_KEY")
+            or get_secret_str("AZURE_API_KEY")
+        )
 
         extra_body = optional_params.get("extra_body", {})
-        azure_ad_token: Optional[str] = None
         if extra_body is not None:
-            azure_ad_token = extra_body.pop("azure_ad_token", None)
+            extra_body.pop("azure_ad_token", None)
         else:
-            azure_ad_token = get_secret("AZURE_AD_TOKEN")  # type: ignore
+            get_secret_str("AZURE_AD_TOKEN")  # type: ignore
 
         response = azure_batches_instance.retrieve_batch(
             _is_async=_is_async,

@@ -384,7 +390,7 @@ async def alist_batches(
     extra_headers: Optional[Dict[str, str]] = None,
     extra_body: Optional[Dict[str, str]] = None,
     **kwargs,
-) -> Batch:
+):
     """
     Async: List your organization's batches.
     """

@@ -482,27 +488,26 @@ def list_batches(
             max_retries=optional_params.max_retries,
         )
     elif custom_llm_provider == "azure":
-        api_base = optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE")  # type: ignore
+        api_base = optional_params.api_base or litellm.api_base or get_secret_str("AZURE_API_BASE")  # type: ignore
         api_version = (
             optional_params.api_version
             or litellm.api_version
-            or get_secret("AZURE_API_VERSION")
-        )  # type: ignore
+            or get_secret_str("AZURE_API_VERSION")
+        )
 
         api_key = (
             optional_params.api_key
             or litellm.api_key
             or litellm.azure_key
-            or get_secret("AZURE_OPENAI_API_KEY")
-            or get_secret("AZURE_API_KEY")
-        )  # type: ignore
+            or get_secret_str("AZURE_OPENAI_API_KEY")
+            or get_secret_str("AZURE_API_KEY")
+        )
 
         extra_body = optional_params.get("extra_body", {})
-        azure_ad_token: Optional[str] = None
         if extra_body is not None:
-            azure_ad_token = extra_body.pop("azure_ad_token", None)
+            extra_body.pop("azure_ad_token", None)
         else:
-            azure_ad_token = get_secret("AZURE_AD_TOKEN")  # type: ignore
+            get_secret_str("AZURE_AD_TOKEN")  # type: ignore
 
         response = azure_batches_instance.list_batches(
             _is_async=_is_async,
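The batches (and files) hunks swap `get_secret` for `get_secret_str` at call sites that must produce an `Optional[str]`, which removes the need for many of the old `# type: ignore` comments under pyright. The wrapper below is only a sketch of the idea, not litellm's implementation:

    import os
    from typing import Any, Optional


    def get_secret(name: str) -> Any:
        # stand-in: a real secret manager can hand back str, bool, dict, ...
        return os.environ.get(name)


    def get_secret_str(name: str) -> Optional[str]:
        """Typed variant: only ever returns a string (or None)."""
        value = get_secret(name)
        return value if isinstance(value, str) else None


    api_base: Optional[str] = get_secret_str("AZURE_API_BASE")  # no `# type: ignore` needed
    print(api_base)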
@@ -7,11 +7,16 @@
 #
 # Thank you users! We ❤️ you! - Krrish & Ishaan
 
-import os, json, time
+import json
+import os
+import threading
+import time
+from typing import Literal, Optional, Union
+
+import requests  # type: ignore
+
 import litellm
 from litellm.utils import ModelResponse
-import requests, threading  # type: ignore
-from typing import Optional, Union, Literal
 
 
 class BudgetManager:

@@ -35,7 +40,7 @@ class BudgetManager:
             import logging
 
             logging.info(print_statement)
-        except:
+        except Exception:
             pass
 
     def load_data(self):

@@ -52,7 +57,6 @@ class BudgetManager:
         elif self.client_type == "hosted":
             # Load the user_dict from hosted db
             url = self.api_base + "/get_budget"
-            headers = {"Content-Type": "application/json"}
             data = {"project_name": self.project_name}
             response = requests.post(url, headers=self.headers, json=data)
             response = response.json()

@@ -210,7 +214,6 @@ class BudgetManager:
             return {"status": "success"}
         elif self.client_type == "hosted":
             url = self.api_base + "/set_budget"
-            headers = {"Content-Type": "application/json"}
             data = {"project_name": self.project_name, "user_dict": self.user_dict}
             response = requests.post(url, headers=self.headers, json=data)
             response = response.json()
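The two deleted `headers = {"Content-Type": "application/json"}` lines are ruff F841 (local variable assigned but never used): the requests below read `self.headers`, so the local was dead code. A tiny sketch of the same situation, with `BudgetClient` as a hypothetical stand-in class:

    class BudgetClient:
        def __init__(self) -> None:
            self.headers = {"Content-Type": "application/json"}

        def set_budget(self, project_name: str) -> dict:
            # Before the lint pass a dead local sat here:
            #     headers = {"Content-Type": "application/json"}   # ruff F841
            # The call below reads self.headers, so the assignment is simply dropped.
            data = {"project_name": project_name}
            return {"url": "/set_budget", "headers": self.headers, "json": data}


    print(BudgetClient().set_budget("demo"))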
@@ -33,7 +33,7 @@ def print_verbose(print_statement):
         verbose_logger.debug(print_statement)
         if litellm.set_verbose:
             print(print_statement)  # noqa
-    except:
+    except Exception:
         pass
 
 

@@ -96,15 +96,13 @@ class InMemoryCache(BaseCache):
         """
         for key in list(self.ttl_dict.keys()):
             if time.time() > self.ttl_dict[key]:
-                removed_item = self.cache_dict.pop(key, None)
-                removed_ttl_item = self.ttl_dict.pop(key, None)
+                self.cache_dict.pop(key, None)
+                self.ttl_dict.pop(key, None)
 
                 # de-reference the removed item
                 # https://www.geeksforgeeks.org/diagnosing-and-fixing-memory-leaks-in-python/
                 # One of the most common causes of memory leaks in Python is the retention of objects that are no longer being used.
                 # This can occur when an object is referenced by another object, but the reference is never removed.
-                removed_item = None
-                removed_ttl_item = None
 
     def set_cache(self, key, value, **kwargs):
         print_verbose(

@@ -150,7 +148,7 @@ class InMemoryCache(BaseCache):
             original_cached_response = self.cache_dict[key]
             try:
                 cached_response = json.loads(original_cached_response)
-            except:
+            except Exception:
                 cached_response = original_cached_response
             return cached_response
         return None
@@ -251,7 +249,7 @@ class RedisCache(BaseCache):
         self.redis_version = "Unknown"
         try:
             self.redis_version = self.redis_client.info()["redis_version"]
-        except Exception as e:
+        except Exception:
             pass
 
         ### ASYNC HEALTH PING ###

@@ -688,7 +686,7 @@ class RedisCache(BaseCache):
                     cached_response = json.loads(
                         cached_response
                     )  # Convert string to dictionary
-                except:
+                except Exception:
                     cached_response = ast.literal_eval(cached_response)
             return cached_response
 

@@ -844,7 +842,7 @@ class RedisCache(BaseCache):
         """
         Tests if the sync redis client is correctly setup.
         """
-        print_verbose(f"Pinging Sync Redis Cache")
+        print_verbose("Pinging Sync Redis Cache")
         start_time = time.time()
         try:
             response = self.redis_client.ping()

@@ -878,7 +876,7 @@ class RedisCache(BaseCache):
         _redis_client = self.init_async_client()
         start_time = time.time()
         async with _redis_client as redis_client:
-            print_verbose(f"Pinging Async Redis Cache")
+            print_verbose("Pinging Async Redis Cache")
             try:
                 response = await redis_client.ping()
                 ## LOGGING ##

@@ -973,7 +971,6 @@ class RedisSemanticCache(BaseCache):
             },
             "fields": {
                 "text": [{"name": "response"}],
-                "text": [{"name": "prompt"}],
                 "vector": [
                     {
                         "name": "litellm_embedding",

@@ -999,14 +996,14 @@ class RedisSemanticCache(BaseCache):
 
         redis_url = "redis://:" + password + "@" + host + ":" + port
         print_verbose(f"redis semantic-cache redis_url: {redis_url}")
-        if use_async == False:
+        if use_async is False:
             self.index = SearchIndex.from_dict(schema)
             self.index.connect(redis_url=redis_url)
             try:
                 self.index.create(overwrite=False)  # don't overwrite existing index
             except Exception as e:
                 print_verbose(f"Got exception creating semantic cache index: {str(e)}")
-        elif use_async == True:
+        elif use_async is True:
             schema["index"]["name"] = "litellm_semantic_cache_index_async"
             self.index = SearchIndex.from_dict(schema)
             self.index.connect(redis_url=redis_url, use_async=True)

@@ -1027,7 +1024,7 @@ class RedisSemanticCache(BaseCache):
                 cached_response = json.loads(
                     cached_response
                 )  # Convert string to dictionary
-            except:
+            except Exception:
                 cached_response = ast.literal_eval(cached_response)
             return cached_response
 

@@ -1060,7 +1057,7 @@ class RedisSemanticCache(BaseCache):
         ]
 
         # Add more data
-        keys = self.index.load(new_data)
+        self.index.load(new_data)
 
         return
 

@@ -1092,7 +1089,7 @@ class RedisSemanticCache(BaseCache):
         )
 
         results = self.index.query(query)
-        if results == None:
+        if results is None:
             return None
         if isinstance(results, list):
             if len(results) == 0:

@@ -1173,7 +1170,7 @@ class RedisSemanticCache(BaseCache):
         ]
 
         # Add more data
-        keys = await self.index.aload(new_data)
+        await self.index.aload(new_data)
         return
 
     async def async_get_cache(self, key, **kwargs):

@@ -1222,7 +1219,7 @@ class RedisSemanticCache(BaseCache):
             return_fields=["response", "prompt", "vector_distance"],
         )
         results = await self.index.aquery(query)
-        if results == None:
+        if results is None:
             kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0
             return None
         if isinstance(results, list):

@@ -1396,7 +1393,7 @@ class QdrantSemanticCache(BaseCache):
                 cached_response = json.loads(
                     cached_response
                 )  # Convert string to dictionary
-            except:
+            except Exception:
                 cached_response = ast.literal_eval(cached_response)
             return cached_response
 

@@ -1435,7 +1432,7 @@ class QdrantSemanticCache(BaseCache):
                 },
             ]
         }
-        keys = self.sync_client.put(
+        self.sync_client.put(
             url=f"{self.qdrant_api_base}/collections/{self.collection_name}/points",
             headers=self.headers,
             json=data,

@@ -1481,7 +1478,7 @@ class QdrantSemanticCache(BaseCache):
         )
         results = search_response.json()["result"]
 
-        if results == None:
+        if results is None:
             return None
         if isinstance(results, list):
             if len(results) == 0:

@@ -1563,7 +1560,7 @@ class QdrantSemanticCache(BaseCache):
             ]
         }
 
-        keys = await self.async_client.put(
+        await self.async_client.put(
             url=f"{self.qdrant_api_base}/collections/{self.collection_name}/points",
             headers=self.headers,
             json=data,

@@ -1629,7 +1626,7 @@ class QdrantSemanticCache(BaseCache):
 
         results = search_response.json()["result"]
 
-        if results == None:
+        if results is None:
             kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0
             return None
         if isinstance(results, list):

@@ -1767,7 +1764,7 @@ class S3Cache(BaseCache):
                 cached_response = json.loads(
                     cached_response
                 )  # Convert string to dictionary
-            except Exception as e:
+            except Exception:
                 cached_response = ast.literal_eval(cached_response)
             if type(cached_response) is not dict:
                 cached_response = dict(cached_response)
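The semantic-cache hunks also convert `if results == None:` to `if results is None:` (ruff/pycodestyle E711). `None` is a singleton, so an identity check is both the conventional spelling and immune to a class overriding `__eq__`:

    class AlwaysEqual:
        def __eq__(self, other: object) -> bool:  # pathological __eq__
            return True


    results = AlwaysEqual()

    print(results == None)  # True  -> equality can be fooled          # noqa: E711
    print(results is None)  # False -> identity check is unambiguous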
@@ -1845,7 +1842,7 @@ class DualCache(BaseCache):
 
                 self.in_memory_cache.set_cache(key, value, **kwargs)
 
-            if self.redis_cache is not None and local_only == False:
+            if self.redis_cache is not None and local_only is False:
                 self.redis_cache.set_cache(key, value, **kwargs)
         except Exception as e:
             print_verbose(e)

@@ -1865,7 +1862,7 @@ class DualCache(BaseCache):
             if self.in_memory_cache is not None:
                 result = self.in_memory_cache.increment_cache(key, value, **kwargs)
 
-            if self.redis_cache is not None and local_only == False:
+            if self.redis_cache is not None and local_only is False:
                 result = self.redis_cache.increment_cache(key, value, **kwargs)
 
             return result

@@ -1887,7 +1884,7 @@ class DualCache(BaseCache):
             if (
                 (self.always_read_redis is True)
                 and self.redis_cache is not None
-                and local_only == False
+                and local_only is False
             ):
                 # If not found in in-memory cache or always_read_redis is True, try fetching from Redis
                 redis_result = self.redis_cache.get_cache(key, **kwargs)

@@ -1900,7 +1897,7 @@ class DualCache(BaseCache):
 
             print_verbose(f"get cache: cache result: {result}")
             return result
-        except Exception as e:
+        except Exception:
             verbose_logger.error(traceback.format_exc())
 
     def batch_get_cache(self, keys: list, local_only: bool = False, **kwargs):

@@ -1913,7 +1910,7 @@ class DualCache(BaseCache):
             if in_memory_result is not None:
                 result = in_memory_result
 
-            if None in result and self.redis_cache is not None and local_only == False:
+            if None in result and self.redis_cache is not None and local_only is False:
                 """
                 - for the none values in the result
                 - check the redis cache

@@ -1933,7 +1930,7 @@ class DualCache(BaseCache):
 
             print_verbose(f"async batch get cache: cache result: {result}")
             return result
-        except Exception as e:
+        except Exception:
             verbose_logger.error(traceback.format_exc())
 
     async def async_get_cache(self, key, local_only: bool = False, **kwargs):

@@ -1952,7 +1949,7 @@ class DualCache(BaseCache):
             if in_memory_result is not None:
                 result = in_memory_result
 
-            if result is None and self.redis_cache is not None and local_only == False:
+            if result is None and self.redis_cache is not None and local_only is False:
                 # If not found in in-memory cache, try fetching from Redis
                 redis_result = await self.redis_cache.async_get_cache(key, **kwargs)
 

@@ -1966,7 +1963,7 @@ class DualCache(BaseCache):
 
             print_verbose(f"get cache: cache result: {result}")
             return result
-        except Exception as e:
+        except Exception:
             verbose_logger.error(traceback.format_exc())
 
     async def async_batch_get_cache(

@@ -1981,7 +1978,7 @@ class DualCache(BaseCache):
 
             if in_memory_result is not None:
                 result = in_memory_result
-            if None in result and self.redis_cache is not None and local_only == False:
+            if None in result and self.redis_cache is not None and local_only is False:
                 """
                 - for the none values in the result
                 - check the redis cache

@@ -2006,7 +2003,7 @@ class DualCache(BaseCache):
                         result[index] = value
 
             return result
-        except Exception as e:
+        except Exception:
             verbose_logger.error(traceback.format_exc())
 
     async def async_set_cache(self, key, value, local_only: bool = False, **kwargs):

@@ -2017,7 +2014,7 @@ class DualCache(BaseCache):
             if self.in_memory_cache is not None:
                 await self.in_memory_cache.async_set_cache(key, value, **kwargs)
 
-            if self.redis_cache is not None and local_only == False:
+            if self.redis_cache is not None and local_only is False:
                 await self.redis_cache.async_set_cache(key, value, **kwargs)
         except Exception as e:
             verbose_logger.exception(

@@ -2039,7 +2036,7 @@ class DualCache(BaseCache):
                     cache_list=cache_list, **kwargs
                 )
 
-            if self.redis_cache is not None and local_only == False:
+            if self.redis_cache is not None and local_only is False:
                 await self.redis_cache.async_set_cache_pipeline(
                     cache_list=cache_list, ttl=kwargs.pop("ttl", None), **kwargs
                 )
@@ -2459,7 +2456,7 @@ class Cache:
                 cached_response = json.loads(
                     cached_response  # type: ignore
                 )  # Convert string to dictionary
-            except:
+            except Exception:
                 cached_response = ast.literal_eval(cached_response)  # type: ignore
             return cached_response
         return cached_result

@@ -2492,7 +2489,7 @@ class Cache:
                 return self._get_cache_logic(
                     cached_result=cached_result, max_age=max_age
                 )
-        except Exception as e:
+        except Exception:
             print_verbose(f"An exception occurred: {traceback.format_exc()}")
             return None
 

@@ -2506,7 +2503,7 @@ class Cache:
         if self.should_use_cache(*args, **kwargs) is not True:
             return
 
-        messages = kwargs.get("messages", [])
+        kwargs.get("messages", [])
         if "cache_key" in kwargs:
             cache_key = kwargs["cache_key"]
         else:

@@ -2522,7 +2519,7 @@ class Cache:
                 return self._get_cache_logic(
                     cached_result=cached_result, max_age=max_age
                 )
-        except Exception as e:
+        except Exception:
             print_verbose(f"An exception occurred: {traceback.format_exc()}")
             return None
 

@@ -2701,7 +2698,7 @@ class DiskCache(BaseCache):
         if original_cached_response:
             try:
                 cached_response = json.loads(original_cached_response)  # type: ignore
-            except:
+            except Exception:
                 cached_response = original_cached_response
             return cached_response
         return None

@@ -2803,7 +2800,7 @@ def enable_cache(
     if "cache" not in litellm._async_success_callback:
         litellm._async_success_callback.append("cache")
 
-    if litellm.cache == None:
+    if litellm.cache is None:
         litellm.cache = Cache(
             type=type,
             host=host,
@@ -57,7 +57,7 @@
 # config = yaml.safe_load(file)
 # else:
 # pass
-# except:
+# except Exception:
 # pass
 
 # ## SERVER SETTINGS (e.g. default completion model = 'ollama/mistral')
@@ -9,12 +9,12 @@ import asyncio
 import contextvars
 import os
 from functools import partial
-from typing import Any, Coroutine, Dict, Literal, Optional, Union
+from typing import Any, Coroutine, Dict, Literal, Optional, Union, cast
 
 import httpx
 
 import litellm
-from litellm import client, get_secret
+from litellm import client, get_secret_str
 from litellm.llms.files_apis.azure import AzureOpenAIFilesAPI
 from litellm.llms.OpenAI.openai import FileDeleted, FileObject, OpenAIFilesAPI
 from litellm.types.llms.openai import (

@@ -39,7 +39,7 @@ async def afile_retrieve(
     extra_headers: Optional[Dict[str, str]] = None,
     extra_body: Optional[Dict[str, str]] = None,
     **kwargs,
-) -> Coroutine[Any, Any, FileObject]:
+):
     """
     Async: Get file contents
 

@@ -66,7 +66,7 @@ async def afile_retrieve(
         if asyncio.iscoroutine(init_response):
             response = await init_response
         else:
-            response = init_response  # type: ignore
+            response = init_response
 
         return response
     except Exception as e:

@@ -137,27 +137,26 @@ def file_retrieve(
             organization=organization,
         )
     elif custom_llm_provider == "azure":
-        api_base = optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE")  # type: ignore
+        api_base = optional_params.api_base or litellm.api_base or get_secret_str("AZURE_API_BASE")  # type: ignore
         api_version = (
             optional_params.api_version
             or litellm.api_version
-            or get_secret("AZURE_API_VERSION")
+            or get_secret_str("AZURE_API_VERSION")
         )  # type: ignore
 
         api_key = (
             optional_params.api_key
             or litellm.api_key
             or litellm.azure_key
-            or get_secret("AZURE_OPENAI_API_KEY")
-            or get_secret("AZURE_API_KEY")
+            or get_secret_str("AZURE_OPENAI_API_KEY")
+            or get_secret_str("AZURE_API_KEY")
         )  # type: ignore
 
         extra_body = optional_params.get("extra_body", {})
-        azure_ad_token: Optional[str] = None
         if extra_body is not None:
-            azure_ad_token = extra_body.pop("azure_ad_token", None)
+            extra_body.pop("azure_ad_token", None)
         else:
-            azure_ad_token = get_secret("AZURE_AD_TOKEN")  # type: ignore
+            get_secret_str("AZURE_AD_TOKEN")  # type: ignore
 
         response = azure_files_instance.retrieve_file(
             _is_async=_is_async,

@@ -181,7 +180,7 @@ def file_retrieve(
                 request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"),  # type: ignore
             ),
         )
-        return response
+        return cast(FileObject, response)
     except Exception as e:
         raise e
 
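`file_retrieve` now ends with `return cast(FileObject, response)` (and the delete paths below with `cast(FileDeleted, ...)`). `typing.cast` is a no-op at runtime; it only tells pyright what the dynamically dispatched `response` actually is. A self-contained sketch, using a local stand-in class rather than the OpenAI SDK types:

    from typing import Any, cast


    class FileObject:  # local stand-in, not the OpenAI SDK class
        def __init__(self, id: str) -> None:
            self.id = id


    def provider_call(file_id: str) -> Any:
        # stand-in for a sync/async dispatch whose static return type is Any
        return FileObject(file_id)


    def file_retrieve(file_id: str) -> FileObject:
        response = provider_call(file_id)
        # cast() does nothing at runtime; it only narrows the type for pyright
        # so the declared `-> FileObject` return type checks cleanly.
        return cast(FileObject, response)


    print(file_retrieve("file-123").id)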
@@ -222,7 +221,7 @@ async def afile_delete(
         else:
             response = init_response  # type: ignore
 
-        return response
+        return cast(FileDeleted, response)  # type: ignore
     except Exception as e:
         raise e
 

@@ -248,7 +247,7 @@ def file_delete(
     if (
         timeout is not None
        and isinstance(timeout, httpx.Timeout)
-        and supports_httpx_timeout(custom_llm_provider) == False
+        and supports_httpx_timeout(custom_llm_provider) is False
     ):
         read_timeout = timeout.read or 600
         timeout = read_timeout  # default 10 min timeout

@@ -288,27 +287,26 @@ def file_delete(
             organization=organization,
         )
     elif custom_llm_provider == "azure":
-        api_base = optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE")  # type: ignore
+        api_base = optional_params.api_base or litellm.api_base or get_secret_str("AZURE_API_BASE")  # type: ignore
         api_version = (
             optional_params.api_version
             or litellm.api_version
-            or get_secret("AZURE_API_VERSION")
+            or get_secret_str("AZURE_API_VERSION")
         )  # type: ignore
 
         api_key = (
             optional_params.api_key
             or litellm.api_key
             or litellm.azure_key
-            or get_secret("AZURE_OPENAI_API_KEY")
-            or get_secret("AZURE_API_KEY")
+            or get_secret_str("AZURE_OPENAI_API_KEY")
+            or get_secret_str("AZURE_API_KEY")
         )  # type: ignore
 
         extra_body = optional_params.get("extra_body", {})
-        azure_ad_token: Optional[str] = None
         if extra_body is not None:
-            azure_ad_token = extra_body.pop("azure_ad_token", None)
+            extra_body.pop("azure_ad_token", None)
         else:
-            azure_ad_token = get_secret("AZURE_AD_TOKEN")  # type: ignore
+            get_secret_str("AZURE_AD_TOKEN")  # type: ignore
 
         response = azure_files_instance.delete_file(
             _is_async=_is_async,

@@ -332,7 +330,7 @@ def file_delete(
                 request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"),  # type: ignore
             ),
         )
-        return response
+        return cast(FileDeleted, response)
     except Exception as e:
         raise e
 

@@ -399,7 +397,7 @@ def file_list(
     if (
         timeout is not None
         and isinstance(timeout, httpx.Timeout)
-        and supports_httpx_timeout(custom_llm_provider) == False
+        and supports_httpx_timeout(custom_llm_provider) is False
     ):
         read_timeout = timeout.read or 600
         timeout = read_timeout  # default 10 min timeout

@@ -441,27 +439,26 @@ def file_list(
             organization=organization,
         )
     elif custom_llm_provider == "azure":
-        api_base = optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE")  # type: ignore
+        api_base = optional_params.api_base or litellm.api_base or get_secret_str("AZURE_API_BASE")  # type: ignore
         api_version = (
             optional_params.api_version
             or litellm.api_version
-            or get_secret("AZURE_API_VERSION")
+            or get_secret_str("AZURE_API_VERSION")
         )  # type: ignore
 
         api_key = (
             optional_params.api_key
             or litellm.api_key
             or litellm.azure_key
-            or get_secret("AZURE_OPENAI_API_KEY")
-            or get_secret("AZURE_API_KEY")
+            or get_secret_str("AZURE_OPENAI_API_KEY")
+            or get_secret_str("AZURE_API_KEY")
         )  # type: ignore
 
         extra_body = optional_params.get("extra_body", {})
-        azure_ad_token: Optional[str] = None
         if extra_body is not None:
-            azure_ad_token = extra_body.pop("azure_ad_token", None)
+            extra_body.pop("azure_ad_token", None)
         else:
-            azure_ad_token = get_secret("AZURE_AD_TOKEN")  # type: ignore
+            get_secret_str("AZURE_AD_TOKEN")  # type: ignore
 
         response = azure_files_instance.list_files(
             _is_async=_is_async,
|
@ -556,7 +553,7 @@ def create_file(
|
||||||
if (
|
if (
|
||||||
timeout is not None
|
timeout is not None
|
||||||
and isinstance(timeout, httpx.Timeout)
|
and isinstance(timeout, httpx.Timeout)
|
||||||
and supports_httpx_timeout(custom_llm_provider) == False
|
and supports_httpx_timeout(custom_llm_provider) is False
|
||||||
):
|
):
|
||||||
read_timeout = timeout.read or 600
|
read_timeout = timeout.read or 600
|
||||||
timeout = read_timeout # default 10 min timeout
|
timeout = read_timeout # default 10 min timeout
|
||||||
|
@ -603,27 +600,26 @@ def create_file(
|
||||||
create_file_data=_create_file_request,
|
create_file_data=_create_file_request,
|
||||||
)
|
)
|
||||||
elif custom_llm_provider == "azure":
|
elif custom_llm_provider == "azure":
|
||||||
api_base = optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE") # type: ignore
|
api_base = optional_params.api_base or litellm.api_base or get_secret_str("AZURE_API_BASE") # type: ignore
|
||||||
api_version = (
|
api_version = (
|
||||||
optional_params.api_version
|
optional_params.api_version
|
||||||
or litellm.api_version
|
or litellm.api_version
|
||||||
or get_secret("AZURE_API_VERSION")
|
or get_secret_str("AZURE_API_VERSION")
|
||||||
) # type: ignore
|
) # type: ignore
|
||||||
|
|
||||||
api_key = (
|
api_key = (
|
||||||
optional_params.api_key
|
optional_params.api_key
|
||||||
or litellm.api_key
|
or litellm.api_key
|
||||||
or litellm.azure_key
|
or litellm.azure_key
|
||||||
or get_secret("AZURE_OPENAI_API_KEY")
|
or get_secret_str("AZURE_OPENAI_API_KEY")
|
||||||
or get_secret("AZURE_API_KEY")
|
or get_secret_str("AZURE_API_KEY")
|
||||||
) # type: ignore
|
) # type: ignore
|
||||||
|
|
||||||
extra_body = optional_params.get("extra_body", {})
|
extra_body = optional_params.get("extra_body", {})
|
||||||
azure_ad_token: Optional[str] = None
|
|
||||||
if extra_body is not None:
|
if extra_body is not None:
|
||||||
azure_ad_token = extra_body.pop("azure_ad_token", None)
|
extra_body.pop("azure_ad_token", None)
|
||||||
else:
|
else:
|
||||||
azure_ad_token = get_secret("AZURE_AD_TOKEN") # type: ignore
|
get_secret_str("AZURE_AD_TOKEN") # type: ignore
|
||||||
|
|
||||||
response = azure_files_instance.create_file(
|
response = azure_files_instance.create_file(
|
||||||
_is_async=_is_async,
|
_is_async=_is_async,
|
||||||
|
@ -713,7 +709,7 @@ def file_content(
|
||||||
if (
|
if (
|
||||||
timeout is not None
|
timeout is not None
|
||||||
and isinstance(timeout, httpx.Timeout)
|
and isinstance(timeout, httpx.Timeout)
|
||||||
and supports_httpx_timeout(custom_llm_provider) == False
|
and supports_httpx_timeout(custom_llm_provider) is False
|
||||||
):
|
):
|
||||||
read_timeout = timeout.read or 600
|
read_timeout = timeout.read or 600
|
||||||
timeout = read_timeout # default 10 min timeout
|
timeout = read_timeout # default 10 min timeout
|
||||||
|
@ -761,27 +757,26 @@ def file_content(
|
||||||
organization=organization,
|
organization=organization,
|
||||||
)
|
)
|
||||||
elif custom_llm_provider == "azure":
|
elif custom_llm_provider == "azure":
|
||||||
api_base = optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE") # type: ignore
|
api_base = optional_params.api_base or litellm.api_base or get_secret_str("AZURE_API_BASE") # type: ignore
|
||||||
api_version = (
|
api_version = (
|
||||||
optional_params.api_version
|
optional_params.api_version
|
||||||
or litellm.api_version
|
or litellm.api_version
|
||||||
or get_secret("AZURE_API_VERSION")
|
or get_secret_str("AZURE_API_VERSION")
|
||||||
) # type: ignore
|
) # type: ignore
|
||||||
|
|
||||||
api_key = (
|
api_key = (
|
||||||
optional_params.api_key
|
optional_params.api_key
|
||||||
or litellm.api_key
|
or litellm.api_key
|
||||||
or litellm.azure_key
|
or litellm.azure_key
|
||||||
or get_secret("AZURE_OPENAI_API_KEY")
|
or get_secret_str("AZURE_OPENAI_API_KEY")
|
||||||
or get_secret("AZURE_API_KEY")
|
or get_secret_str("AZURE_API_KEY")
|
||||||
) # type: ignore
|
) # type: ignore
|
||||||
|
|
||||||
extra_body = optional_params.get("extra_body", {})
|
extra_body = optional_params.get("extra_body", {})
|
||||||
azure_ad_token: Optional[str] = None
|
|
||||||
if extra_body is not None:
|
if extra_body is not None:
|
||||||
azure_ad_token = extra_body.pop("azure_ad_token", None)
|
extra_body.pop("azure_ad_token", None)
|
||||||
else:
|
else:
|
||||||
azure_ad_token = get_secret("AZURE_AD_TOKEN") # type: ignore
|
get_secret_str("AZURE_AD_TOKEN") # type: ignore
|
||||||
|
|
||||||
response = azure_files_instance.file_content(
|
response = azure_files_instance.file_content(
|
||||||
_is_async=_is_async,
|
_is_async=_is_async,
|
||||||
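Note on the hunks above: they swap `get_secret(...)` for `get_secret_str(...)` from `litellm.secret_managers.main`, so resolved Azure settings stay typed as an optional string. A minimal sketch of the idea, with an illustrative stand-in helper (the real one lives in litellm; this is not its implementation):

import os
from typing import Optional

def get_secret_str(name: str) -> Optional[str]:
    # Illustrative stand-in: resolve a secret and guarantee str-or-None,
    # so callers keep Optional[str] instead of Any and pyright stays quiet.
    value = os.getenv(name)
    return value if isinstance(value, str) else None

# Fallback chains then stay narrowly typed:
api_key: Optional[str] = get_secret_str("AZURE_OPENAI_API_KEY") or get_secret_str("AZURE_API_KEY")
print(api_key)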
@@ -25,7 +25,7 @@ from litellm.llms.fine_tuning_apis.openai import (
OpenAIFineTuningAPI,
)
from litellm.llms.fine_tuning_apis.vertex_ai import VertexFineTuningAPI
-from litellm.secret_managers.main import get_secret
+from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import Hyperparameters
from litellm.types.router import *
from litellm.utils import supports_httpx_timeout
@@ -119,7 +119,7 @@ def create_fine_tuning_job(
if (
timeout is not None
and isinstance(timeout, httpx.Timeout)
-and supports_httpx_timeout(custom_llm_provider) == False
+and supports_httpx_timeout(custom_llm_provider) is False
):
read_timeout = timeout.read or 600
timeout = read_timeout # default 10 min timeout
@@ -177,28 +177,27 @@ def create_fine_tuning_job(
)
# Azure OpenAI
elif custom_llm_provider == "azure":
-api_base = optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE") # type: ignore
+api_base = optional_params.api_base or litellm.api_base or get_secret_str("AZURE_API_BASE") # type: ignore

api_version = (
optional_params.api_version
or litellm.api_version
-or get_secret("AZURE_API_VERSION")
+or get_secret_str("AZURE_API_VERSION")
) # type: ignore

api_key = (
optional_params.api_key
or litellm.api_key
or litellm.azure_key
-or get_secret("AZURE_OPENAI_API_KEY")
+or get_secret_str("AZURE_OPENAI_API_KEY")
-or get_secret("AZURE_API_KEY")
+or get_secret_str("AZURE_API_KEY")
) # type: ignore

extra_body = optional_params.get("extra_body", {})
-azure_ad_token: Optional[str] = None
if extra_body is not None:
-azure_ad_token = extra_body.pop("azure_ad_token", None)
+extra_body.pop("azure_ad_token", None)
else:
-azure_ad_token = get_secret("AZURE_AD_TOKEN") # type: ignore
+get_secret_str("AZURE_AD_TOKEN") # type: ignore

create_fine_tuning_job_data = FineTuningJobCreate(
model=model,
@@ -228,14 +227,14 @@ def create_fine_tuning_job(
vertex_ai_project = (
optional_params.vertex_project
or litellm.vertex_project
-or get_secret("VERTEXAI_PROJECT")
+or get_secret_str("VERTEXAI_PROJECT")
)
vertex_ai_location = (
optional_params.vertex_location
or litellm.vertex_location
-or get_secret("VERTEXAI_LOCATION")
+or get_secret_str("VERTEXAI_LOCATION")
)
-vertex_credentials = optional_params.vertex_credentials or get_secret(
+vertex_credentials = optional_params.vertex_credentials or get_secret_str(
"VERTEXAI_CREDENTIALS"
)
create_fine_tuning_job_data = FineTuningJobCreate(
@@ -315,7 +314,7 @@ async def acancel_fine_tuning_job(

def cancel_fine_tuning_job(
fine_tuning_job_id: str,
-custom_llm_provider: Literal["openai"] = "openai",
+custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
extra_headers: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, str]] = None,
**kwargs,
@@ -335,7 +334,7 @@ def cancel_fine_tuning_job(
if (
timeout is not None
and isinstance(timeout, httpx.Timeout)
-and supports_httpx_timeout(custom_llm_provider) == False
+and supports_httpx_timeout(custom_llm_provider) is False
):
read_timeout = timeout.read or 600
timeout = read_timeout # default 10 min timeout
@@ -386,23 +385,22 @@ def cancel_fine_tuning_job(
api_version = (
optional_params.api_version
or litellm.api_version
-or get_secret("AZURE_API_VERSION")
+or get_secret_str("AZURE_API_VERSION")
) # type: ignore

api_key = (
optional_params.api_key
or litellm.api_key
or litellm.azure_key
-or get_secret("AZURE_OPENAI_API_KEY")
+or get_secret_str("AZURE_OPENAI_API_KEY")
-or get_secret("AZURE_API_KEY")
+or get_secret_str("AZURE_API_KEY")
) # type: ignore

extra_body = optional_params.get("extra_body", {})
-azure_ad_token: Optional[str] = None
if extra_body is not None:
-azure_ad_token = extra_body.pop("azure_ad_token", None)
+extra_body.pop("azure_ad_token", None)
else:
-azure_ad_token = get_secret("AZURE_AD_TOKEN") # type: ignore
+get_secret_str("AZURE_AD_TOKEN") # type: ignore

response = azure_fine_tuning_apis_instance.cancel_fine_tuning_job(
api_base=api_base,
@@ -438,7 +436,7 @@ async def alist_fine_tuning_jobs(
extra_headers: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, str]] = None,
**kwargs,
-) -> FineTuningJob:
+):
"""
Async: List your organization's fine-tuning jobs
"""
@@ -473,7 +471,7 @@ async def alist_fine_tuning_jobs(
def list_fine_tuning_jobs(
after: Optional[str] = None,
limit: Optional[int] = None,
-custom_llm_provider: Literal["openai"] = "openai",
+custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
extra_headers: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, str]] = None,
**kwargs,
@@ -495,7 +493,7 @@ def list_fine_tuning_jobs(
if (
timeout is not None
and isinstance(timeout, httpx.Timeout)
-and supports_httpx_timeout(custom_llm_provider) == False
+and supports_httpx_timeout(custom_llm_provider) is False
):
read_timeout = timeout.read or 600
timeout = read_timeout # default 10 min timeout
@@ -542,28 +540,27 @@ def list_fine_tuning_jobs(
)
# Azure OpenAI
elif custom_llm_provider == "azure":
-api_base = optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE") # type: ignore
+api_base = optional_params.api_base or litellm.api_base or get_secret_str("AZURE_API_BASE") # type: ignore

api_version = (
optional_params.api_version
or litellm.api_version
-or get_secret("AZURE_API_VERSION")
+or get_secret_str("AZURE_API_VERSION")
) # type: ignore

api_key = (
optional_params.api_key
or litellm.api_key
or litellm.azure_key
-or get_secret("AZURE_OPENAI_API_KEY")
+or get_secret_str("AZURE_OPENAI_API_KEY")
-or get_secret("AZURE_API_KEY")
+or get_secret_str("AZURE_API_KEY")
) # type: ignore

extra_body = optional_params.get("extra_body", {})
-azure_ad_token: Optional[str] = None
if extra_body is not None:
-azure_ad_token = extra_body.pop("azure_ad_token", None)
+extra_body.pop("azure_ad_token", None)
else:
-azure_ad_token = get_secret("AZURE_AD_TOKEN") # type: ignore
+get_secret("AZURE_AD_TOKEN") # type: ignore

response = azure_fine_tuning_apis_instance.list_fine_tuning_jobs(
api_base=api_base,
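Note on the repeated `== False` changes above: they apply ruff rule E712, which prefers identity checks against the boolean singletons. A short illustrative sketch of why equality and identity can disagree (assumes nothing beyond the standard library):

class AlwaysEqual:
    # An overridden __eq__ makes `== False` answer True for any comparison,
    # which is exactly the surprise E712 guards against.
    def __eq__(self, other: object) -> bool:
        return True

flag = AlwaysEqual()
print(flag == False)  # True - equality can be overridden by the object
print(flag is False)  # False - identity against the False singleton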
@@ -23,6 +23,9 @@ import litellm.types
from litellm._logging import verbose_logger, verbose_proxy_logger
from litellm.caching import DualCache
from litellm.integrations.custom_batch_logger import CustomBatchLogger
+from litellm.litellm_core_utils.exception_mapping_utils import (
+_add_key_name_and_team_to_alert,
+)
from litellm.litellm_core_utils.litellm_logging import Logging
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
@@ -219,7 +222,7 @@ class SlackAlerting(CustomBatchLogger):
and "metadata" in kwargs["litellm_params"]
):
_metadata: dict = kwargs["litellm_params"]["metadata"]
-request_info = litellm.utils._add_key_name_and_team_to_alert(
+request_info = _add_key_name_and_team_to_alert(
request_info=request_info, metadata=_metadata
)

@@ -281,7 +284,7 @@ class SlackAlerting(CustomBatchLogger):
return_val += 1

return return_val
-except Exception as e:
+except Exception:
return 0

async def send_daily_reports(self, router) -> bool:
@@ -455,7 +458,7 @@ class SlackAlerting(CustomBatchLogger):
try:
messages = str(messages)
messages = messages[:100]
-except:
+except Exception:
messages = ""

if (
@@ -508,7 +511,7 @@ class SlackAlerting(CustomBatchLogger):
_metadata: dict = request_data["metadata"]
_api_base = _metadata.get("api_base", "")

-request_info = litellm.utils._add_key_name_and_team_to_alert(
+request_info = _add_key_name_and_team_to_alert(
request_info=request_info, metadata=_metadata
)

@@ -846,7 +849,7 @@ class SlackAlerting(CustomBatchLogger):

## MINOR OUTAGE ALERT SENT ##
if (
-outage_value["minor_alert_sent"] == False
+outage_value["minor_alert_sent"] is False
and len(outage_value["alerts"])
>= self.alerting_args.minor_outage_alert_threshold
and len(_deployment_set) > 1 # make sure it's not just 1 bad deployment
@@ -871,7 +874,7 @@ class SlackAlerting(CustomBatchLogger):

## MAJOR OUTAGE ALERT SENT ##
elif (
-outage_value["major_alert_sent"] == False
+outage_value["major_alert_sent"] is False
and len(outage_value["alerts"])
>= self.alerting_args.major_outage_alert_threshold
and len(_deployment_set) > 1 # make sure it's not just 1 bad deployment
@@ -941,7 +944,7 @@ class SlackAlerting(CustomBatchLogger):
if provider is None:
try:
model, provider, _, _ = litellm.get_llm_provider(model=model)
-except Exception as e:
+except Exception:
provider = ""
api_base = litellm.get_api_base(
model=model, optional_params=deployment.litellm_params
@@ -976,7 +979,7 @@ class SlackAlerting(CustomBatchLogger):

## MINOR OUTAGE ALERT SENT ##
if (
-outage_value["minor_alert_sent"] == False
+outage_value["minor_alert_sent"] is False
and len(outage_value["alerts"])
>= self.alerting_args.minor_outage_alert_threshold
):
@@ -998,7 +1001,7 @@ class SlackAlerting(CustomBatchLogger):
# set to true
outage_value["minor_alert_sent"] = True
elif (
-outage_value["major_alert_sent"] == False
+outage_value["major_alert_sent"] is False
and len(outage_value["alerts"])
>= self.alerting_args.major_outage_alert_threshold
):
@@ -1024,7 +1027,7 @@ class SlackAlerting(CustomBatchLogger):
await self.internal_usage_cache.async_set_cache(
key=deployment_id, value=outage_value
)
-except Exception as e:
+except Exception:
pass

async def model_added_alert(
@@ -1177,7 +1180,6 @@ Model Info:
if user_row is not None:
recipient_email = user_row.user_email

-key_name = webhook_event.key_alias
key_token = webhook_event.token
key_budget = webhook_event.max_budget
base_url = os.getenv("PROXY_BASE_URL", "http://0.0.0.0:4000")
@@ -1221,14 +1223,14 @@ Model Info:
extra=webhook_event.model_dump(),
)

-payload = webhook_event.model_dump_json()
+webhook_event.model_dump_json()
email_event = {
"to": recipient_email,
"subject": f"LiteLLM: {event_name}",
"html": email_html_content,
}

-response = await send_email(
+await send_email(
receiver_email=email_event["to"],
subject=email_event["subject"],
html=email_event["html"],
@@ -1292,14 +1294,14 @@ Model Info:
The LiteLLM team <br />
"""

-payload = webhook_event.model_dump_json()
+webhook_event.model_dump_json()
email_event = {
"to": recipient_email,
"subject": f"LiteLLM: {event_name}",
"html": email_html_content,
}

-response = await send_email(
+await send_email(
receiver_email=email_event["to"],
subject=email_event["subject"],
html=email_event["html"],
@@ -1446,7 +1448,6 @@ Model Info:
response_s: timedelta = end_time - start_time

final_value = response_s
-total_tokens = 0

if isinstance(response_obj, litellm.ModelResponse) and (
hasattr(response_obj, "usage")
@@ -1505,7 +1506,7 @@ Model Info:
await self.region_outage_alerts(
exception=kwargs["exception"], deployment_id=model_id
)
-except Exception as e:
+except Exception:
pass

async def _run_scheduler_helper(self, llm_router) -> bool:
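Note on the exception-handling changes above: bare `except:` (ruff E722) and unused `except Exception as e:` (F841) become `except Exception:`. A minimal sketch of the behavioural difference, assuming only the standard library:

def shorten(messages) -> str:
    try:
        return str(messages)[:100]
    except Exception:
        # Catches ordinary runtime errors only. A bare `except:` would also
        # swallow KeyboardInterrupt and SystemExit, which should propagate.
        return ""

print(shorten({"role": "user", "content": "hi"}))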
@@ -35,7 +35,7 @@ class LiteLLMBase(BaseModel):
def json(self, **kwargs): # type: ignore
try:
return self.model_dump() # noqa
-except:
+except Exception:
# if using pydantic v1
return self.dict()
@@ -1,8 +1,10 @@
#### What this does ####
# On success + failure, log events to aispend.io
-import dotenv, os
-import traceback
import datetime
+import os
+import traceback

+import dotenv

model_cost = {
"gpt-3.5-turbo": {
@@ -118,8 +120,6 @@ class AISpendLogger:
for model in model_cost:
input_cost_sum += model_cost[model]["input_cost_per_token"]
output_cost_sum += model_cost[model]["output_cost_per_token"]
-avg_input_cost = input_cost_sum / len(model_cost.keys())
-avg_output_cost = output_cost_sum / len(model_cost.keys())
prompt_tokens_cost_usd_dollar = (
model_cost[model]["input_cost_per_token"]
* response_obj["usage"]["prompt_tokens"]
@@ -137,12 +137,6 @@ class AISpendLogger:
f"AISpend Logging - Enters logging function for model {model}"
)

-url = f"https://aispend.io/api/v1/accounts/{self.account_id}/data"
-headers = {
-"Authorization": f"Bearer {self.api_key}",
-"Content-Type": "application/json",
-}

response_timestamp = datetime.datetime.fromtimestamp(
int(response_obj["created"])
).strftime("%Y-%m-%d")
@@ -168,6 +162,6 @@ class AISpendLogger:
]

print_verbose(f"AISpend Logging - final data object: {data}")
-except:
+except Exception:
print_verbose(f"AISpend Logging Error - {traceback.format_exc()}")
pass
@@ -23,7 +23,7 @@ def set_arize_ai_attributes(span: Span, kwargs, response_obj):
)

optional_params = kwargs.get("optional_params", {})
-litellm_params = kwargs.get("litellm_params", {}) or {}
+# litellm_params = kwargs.get("litellm_params", {}) or {}

#############################################
############ LLM CALL METADATA ##############
@@ -1,5 +1,6 @@
import datetime


class AthinaLogger:
def __init__(self):
import os
@@ -23,17 +24,20 @@ class AthinaLogger:
]

def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose):
-import requests # type: ignore
import json
import traceback

+import requests # type: ignore

try:
is_stream = kwargs.get("stream", False)
if is_stream:
if "complete_streaming_response" in kwargs:
# Log the completion response in streaming mode
completion_response = kwargs["complete_streaming_response"]
-response_json = completion_response.model_dump() if completion_response else {}
+response_json = (
+completion_response.model_dump() if completion_response else {}
+)
else:
# Skip logging if the completion response is not available
return
@@ -52,8 +56,8 @@ class AthinaLogger:
}

if (
-type(end_time) == datetime.datetime
+type(end_time) is datetime.datetime
-and type(start_time) == datetime.datetime
+and type(start_time) is datetime.datetime
):
data["response_time"] = int(
(end_time - start_time).total_seconds() * 1000
@@ -1,10 +1,11 @@
#### What this does ####
# On success + failure, log events to aispend.io
-import dotenv, os
-import requests # type: ignore

-import traceback
import datetime
+import os
+import traceback

+import dotenv
+import requests # type: ignore

model_cost = {
"gpt-3.5-turbo": {
@@ -92,91 +93,12 @@ class BerriSpendLogger:
self.account_id = os.getenv("BERRISPEND_ACCOUNT_ID")

def price_calculator(self, model, response_obj, start_time, end_time):
-# try and find if the model is in the model_cost map
+return
-# else default to the average of the costs
-prompt_tokens_cost_usd_dollar = 0
-completion_tokens_cost_usd_dollar = 0
-if model in model_cost:
-prompt_tokens_cost_usd_dollar = (
-model_cost[model]["input_cost_per_token"]
-* response_obj["usage"]["prompt_tokens"]
-)
-completion_tokens_cost_usd_dollar = (
-model_cost[model]["output_cost_per_token"]
-* response_obj["usage"]["completion_tokens"]
-)
-elif "replicate" in model:
-# replicate models are charged based on time
-# llama 2 runs on an nvidia a100 which costs $0.0032 per second - https://replicate.com/replicate/llama-2-70b-chat
-model_run_time = end_time - start_time # assuming time in seconds
-cost_usd_dollar = model_run_time * 0.0032
-prompt_tokens_cost_usd_dollar = cost_usd_dollar / 2
-completion_tokens_cost_usd_dollar = cost_usd_dollar / 2
-else:
-# calculate average input cost
-input_cost_sum = 0
-output_cost_sum = 0
-for model in model_cost:
-input_cost_sum += model_cost[model]["input_cost_per_token"]
-output_cost_sum += model_cost[model]["output_cost_per_token"]
-avg_input_cost = input_cost_sum / len(model_cost.keys())
-avg_output_cost = output_cost_sum / len(model_cost.keys())
-prompt_tokens_cost_usd_dollar = (
-model_cost[model]["input_cost_per_token"]
-* response_obj["usage"]["prompt_tokens"]
-)
-completion_tokens_cost_usd_dollar = (
-model_cost[model]["output_cost_per_token"]
-* response_obj["usage"]["completion_tokens"]
-)
-return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar

def log_event(
self, model, messages, response_obj, start_time, end_time, print_verbose
):
-# Method definition
-try:
-print_verbose(
-f"BerriSpend Logging - Enters logging function for model {model}"
-)
+"""
+This integration is not implemented yet.
+"""
+return

-url = f"https://berrispend.berri.ai/spend"
-headers = {"Content-Type": "application/json"}

-(
-prompt_tokens_cost_usd_dollar,
-completion_tokens_cost_usd_dollar,
-) = self.price_calculator(model, response_obj, start_time, end_time)
-total_cost = (
-prompt_tokens_cost_usd_dollar + completion_tokens_cost_usd_dollar
-)

-response_time = (end_time - start_time).total_seconds()
-if "response" in response_obj:
-data = [
-{
-"response_time": response_time,
-"model_id": response_obj["model"],
-"total_cost": total_cost,
-"messages": messages,
-"response": response_obj["choices"][0]["message"]["content"],
-"account_id": self.account_id,
-}
-]
-elif "error" in response_obj:
-data = [
-{
-"response_time": response_time,
-"model_id": response_obj["model"],
-"total_cost": total_cost,
-"messages": messages,
-"error": response_obj["error"],
-"account_id": self.account_id,
-}
-]

-print_verbose(f"BerriSpend Logging - final data object: {data}")
-response = requests.post(url, headers=headers, json=data)
-except:
-print_verbose(f"BerriSpend Logging Error - {traceback.format_exc()}")
-pass
@@ -136,27 +136,23 @@ class BraintrustLogger(CustomLogger):
project_id = self.default_project_id

prompt = {"messages": kwargs.get("messages")}
+output = None
if response_obj is not None and (
kwargs.get("call_type", None) == "embedding"
or isinstance(response_obj, litellm.EmbeddingResponse)
):
-input = prompt
output = None
elif response_obj is not None and isinstance(
response_obj, litellm.ModelResponse
):
-input = prompt
output = response_obj["choices"][0]["message"].json()
elif response_obj is not None and isinstance(
response_obj, litellm.TextCompletionResponse
):
-input = prompt
output = response_obj.choices[0].text
elif response_obj is not None and isinstance(
response_obj, litellm.ImageResponse
):
-input = prompt
output = response_obj["data"]

litellm_params = kwargs.get("litellm_params", {})
@@ -169,7 +165,7 @@ class BraintrustLogger(CustomLogger):
metadata = copy.deepcopy(
metadata
) # Avoid modifying the original metadata
-except:
+except Exception:
new_metadata = {}
for key, value in metadata.items():
if (
@@ -210,16 +206,13 @@ class BraintrustLogger(CustomLogger):
clean_metadata["litellm_response_cost"] = cost

metrics: Optional[dict] = None
-if (
-response_obj is not None
-and hasattr(response_obj, "usage")
-and isinstance(response_obj.usage, litellm.Usage)
-):
-generation_id = litellm.utils.get_logging_id(start_time, response_obj)
+usage_obj = getattr(response_obj, "usage", None)
+if usage_obj and isinstance(usage_obj, litellm.Usage):
+litellm.utils.get_logging_id(start_time, response_obj)
metrics = {
-"prompt_tokens": response_obj.usage.prompt_tokens,
+"prompt_tokens": usage_obj.prompt_tokens,
-"completion_tokens": response_obj.usage.completion_tokens,
+"completion_tokens": usage_obj.completion_tokens,
-"total_tokens": response_obj.usage.total_tokens,
+"total_tokens": usage_obj.total_tokens,
"total_cost": cost,
}

@@ -255,27 +248,23 @@ class BraintrustLogger(CustomLogger):
project_id = self.default_project_id

prompt = {"messages": kwargs.get("messages")}
+output = None
if response_obj is not None and (
kwargs.get("call_type", None) == "embedding"
or isinstance(response_obj, litellm.EmbeddingResponse)
):
-input = prompt
output = None
elif response_obj is not None and isinstance(
response_obj, litellm.ModelResponse
):
-input = prompt
output = response_obj["choices"][0]["message"].json()
elif response_obj is not None and isinstance(
response_obj, litellm.TextCompletionResponse
):
-input = prompt
output = response_obj.choices[0].text
elif response_obj is not None and isinstance(
response_obj, litellm.ImageResponse
):
-input = prompt
output = response_obj["data"]

litellm_params = kwargs.get("litellm_params", {})
@@ -331,16 +320,13 @@ class BraintrustLogger(CustomLogger):
clean_metadata["litellm_response_cost"] = cost

metrics: Optional[dict] = None
-if (
-response_obj is not None
-and hasattr(response_obj, "usage")
-and isinstance(response_obj.usage, litellm.Usage)
-):
-generation_id = litellm.utils.get_logging_id(start_time, response_obj)
+usage_obj = getattr(response_obj, "usage", None)
+if usage_obj and isinstance(usage_obj, litellm.Usage):
+litellm.utils.get_logging_id(start_time, response_obj)
metrics = {
-"prompt_tokens": response_obj.usage.prompt_tokens,
+"prompt_tokens": usage_obj.prompt_tokens,
-"completion_tokens": response_obj.usage.completion_tokens,
+"completion_tokens": usage_obj.completion_tokens,
-"total_tokens": response_obj.usage.total_tokens,
+"total_tokens": usage_obj.total_tokens,
"total_cost": cost,
}
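Note on the Braintrust hunks above: the `hasattr`/`isinstance` chain collapses into a single `getattr(response_obj, "usage", None)` lookup, and token counts are read off the local variable so the type-checker never dereferences a possibly-missing attribute. A minimal sketch of the pattern, with hypothetical stand-in classes for illustration only:

from dataclasses import dataclass
from typing import Optional

@dataclass
class Usage:  # stand-in for the real usage object
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int

@dataclass
class FakeResponse:  # stand-in for a ModelResponse-like object
    usage: Optional[Usage] = None

def build_metrics(response_obj: object) -> Optional[dict]:
    metrics: Optional[dict] = None
    usage_obj = getattr(response_obj, "usage", None)  # None when attr missing
    if usage_obj and isinstance(usage_obj, Usage):
        metrics = {
            "prompt_tokens": usage_obj.prompt_tokens,
            "completion_tokens": usage_obj.completion_tokens,
            "total_tokens": usage_obj.total_tokens,
        }
    return metrics

print(build_metrics(FakeResponse(Usage(3, 5, 8))))
print(build_metrics(object()))  # no usage attribute -> None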
@@ -2,25 +2,24 @@

#### What this does ####
# On success, logs events to Promptlayer
-import dotenv, os
-from litellm.proxy._types import UserAPIKeyAuth
-from litellm.caching import DualCache
-from typing import Literal, Union
+import datetime
+import json
+import os
import traceback
+from typing import Literal, Optional, Union

+import dotenv
+import requests

+import litellm
+from litellm._logging import verbose_logger
+from litellm.caching import DualCache
+from litellm.proxy._types import UserAPIKeyAuth
+from litellm.types.utils import StandardLoggingPayload

#### What this does ####
# On success + failure, log events to Supabase

-import dotenv, os
-import requests
-import traceback
-import datetime, subprocess, sys
-import litellm, uuid
-from litellm._logging import print_verbose, verbose_logger


def create_client():
try:
@@ -260,18 +259,12 @@ class ClickhouseLogger:
f"ClickhouseLogger Logging - Enters logging function for model {kwargs}"
)
# follows the same params as langfuse.py
-from litellm.proxy.utils import get_logging_payload

-payload = get_logging_payload(
-kwargs=kwargs,
-response_obj=response_obj,
-start_time=start_time,
-end_time=end_time,
+payload: Optional[StandardLoggingPayload] = kwargs.get(
+"standard_logging_object"
)
-metadata = payload.get("metadata", "") or ""
-request_tags = payload.get("request_tags", "") or ""
-payload["metadata"] = str(metadata)
-payload["request_tags"] = str(request_tags)
+if payload is None:
+return
# Build the initial payload

verbose_logger.debug(f"\nClickhouse Logger - Logging payload = {payload}")
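Note on the Clickhouse hunk above: instead of rebuilding the payload via `get_logging_payload(...)`, the logger now reads the pre-built `standard_logging_object` off `kwargs` and bails out early when it is absent. A sketch of that early-return shape (the payload contents here are hypothetical):

from typing import Any, Dict, Optional

def log_event(**kwargs: Any) -> Optional[Dict[str, Any]]:
    # Assumes the logging pipeline already attached the payload to kwargs.
    payload: Optional[Dict[str, Any]] = kwargs.get("standard_logging_object")
    if payload is None:
        return None  # nothing to log - skip instead of reconstructing it
    return payload

print(log_event(standard_logging_object={"model": "gpt-3.5-turbo"}))
print(log_event())  # -> None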
@@ -12,7 +12,12 @@ from litellm.caching import DualCache
from litellm.proxy._types import UserAPIKeyAuth
from litellm.types.llms.openai import ChatCompletionRequest
from litellm.types.services import ServiceLoggerPayload
-from litellm.types.utils import AdapterCompletionStreamWrapper, ModelResponse
+from litellm.types.utils import (
+AdapterCompletionStreamWrapper,
+EmbeddingResponse,
+ImageResponse,
+ModelResponse,
+)


class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callback#callback-class
@@ -140,8 +145,8 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac
self,
data: dict,
user_api_key_dict: UserAPIKeyAuth,
-response,
+response: Union[Any, ModelResponse, EmbeddingResponse, ImageResponse],
-):
+) -> Any:
pass

async def async_logging_hook(
@@ -188,7 +193,7 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac
kwargs,
)
print_verbose(f"Custom Logger - model call details: {kwargs}")
-except:
+except Exception:
print_verbose(f"Custom Logger Error - {traceback.format_exc()}")

async def async_log_input_event(
@@ -202,7 +207,7 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac
kwargs,
)
print_verbose(f"Custom Logger - model call details: {kwargs}")
-except:
+except Exception:
print_verbose(f"Custom Logger Error - {traceback.format_exc()}")

def log_event(
@@ -217,7 +222,7 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac
start_time,
end_time,
)
-except:
+except Exception:
print_verbose(f"Custom Logger Error - {traceback.format_exc()}")
pass

@@ -233,6 +238,6 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac
start_time,
end_time,
)
-except:
+except Exception:
print_verbose(f"Custom Logger Error - {traceback.format_exc()}")
pass
@@ -54,7 +54,7 @@ class DataDogLogger(CustomBatchLogger):
`DD_SITE` - your datadog site, example = `"us5.datadoghq.com"`
"""
try:
-verbose_logger.debug(f"Datadog: in init datadog logger")
+verbose_logger.debug("Datadog: in init datadog logger")
# check if the correct env variables are set
if os.getenv("DD_API_KEY", None) is None:
raise Exception("DD_API_KEY is not set, set 'DD_API_KEY=<>")
@@ -245,12 +245,12 @@ class DataDogLogger(CustomBatchLogger):
usage = dict(usage)
try:
response_time = (end_time - start_time).total_seconds() * 1000
-except:
+except Exception:
response_time = None

try:
response_obj = dict(response_obj)
-except:
+except Exception:
response_obj = response_obj

# Clean Metadata before logging - never log raw metadata

@@ -7,7 +7,7 @@ def make_json_serializable(payload):
elif not isinstance(value, (str, int, float, bool, type(None))):
# everything else becomes a string
payload[key] = str(value)
-except:
+except Exception:
# non blocking if it can't cast to a str
pass
return payload
@@ -1,12 +1,16 @@
#### What this does ####
# On success + failure, log events to Supabase

-import dotenv, os
-import requests # type: ignore
-import datetime, subprocess, sys
-import litellm, uuid
-from litellm._logging import print_verbose
+import datetime
+import os
import traceback
+import uuid
+from typing import Any

+import dotenv
+import requests # type: ignore

+import litellm


class DyanmoDBLogger:
@@ -16,7 +20,7 @@ class DyanmoDBLogger:
# Instance variables
import boto3

-self.dynamodb = boto3.resource(
+self.dynamodb: Any = boto3.resource(
"dynamodb", region_name=os.environ["AWS_REGION_NAME"]
)
if litellm.dynamodb_table_name is None:
@@ -67,7 +71,7 @@ class DyanmoDBLogger:
for key, value in payload.items():
try:
payload[key] = str(value)
-except:
+except Exception:
# non blocking if it can't cast to a str
pass

@@ -84,6 +88,6 @@ class DyanmoDBLogger:
f"DynamoDB Layer Logging - final response object: {response_obj}"
)
return response
-except:
+except Exception:
print_verbose(f"DynamoDB Layer Error - {traceback.format_exc()}")
pass
@@ -9,6 +9,7 @@ import litellm
from litellm._logging import verbose_logger
from litellm.integrations.custom_logger import CustomLogger
from litellm.llms.custom_httpx.http_handler import (
+_get_httpx_client,
get_async_httpx_client,
httpxSpecialProvider,
)
@@ -41,7 +42,7 @@ class GalileoObserve(CustomLogger):
self.batch_size = 1
self.base_url = os.getenv("GALILEO_BASE_URL", None)
self.project_id = os.getenv("GALILEO_PROJECT_ID", None)
-self.headers = None
+self.headers: Optional[Dict[str, str]] = None
self.async_httpx_handler = get_async_httpx_client(
llm_provider=httpxSpecialProvider.LoggingCallback
)
@@ -54,7 +55,7 @@ class GalileoObserve(CustomLogger):
"accept": "application/json",
"Content-Type": "application/x-www-form-urlencoded",
}
-galileo_login_response = self.async_httpx_handler.post(
+galileo_login_response = litellm.module_level_client.post(
url=f"{self.base_url}/login",
headers=headers,
data={
@@ -94,13 +95,9 @@ class GalileoObserve(CustomLogger):
return output

async def async_log_success_event(
-self,
-kwargs,
-start_time,
-end_time,
-response_obj,
+self, kwargs: Any, response_obj: Any, start_time: Any, end_time: Any
):
-verbose_logger.debug(f"On Async Success")
+verbose_logger.debug("On Async Success")

_latency_ms = int((end_time - start_time).total_seconds() * 1000)
_call_type = kwargs.get("call_type", "litellm")
@@ -116,26 +113,27 @@ class GalileoObserve(CustomLogger):
response_obj=response_obj, kwargs=kwargs
)

-request_record = LLMResponse(
-latency_ms=_latency_ms,
-status_code=200,
-input_text=input_text,
-output_text=output_text,
-node_type=_call_type,
-model=kwargs.get("model", "-"),
-num_input_tokens=num_input_tokens,
-num_output_tokens=num_output_tokens,
-created_at=start_time.strftime(
-"%Y-%m-%dT%H:%M:%S"
-), # timestamp str constructed in "%Y-%m-%dT%H:%M:%S" format
-)
+if output_text is not None:
+request_record = LLMResponse(
+latency_ms=_latency_ms,
+status_code=200,
+input_text=input_text,
+output_text=output_text,
+node_type=_call_type,
+model=kwargs.get("model", "-"),
+num_input_tokens=num_input_tokens,
+num_output_tokens=num_output_tokens,
+created_at=start_time.strftime(
+"%Y-%m-%dT%H:%M:%S"
+), # timestamp str constructed in "%Y-%m-%dT%H:%M:%S" format
+)

# dump to dict
request_dict = request_record.model_dump()
self.in_memory_records.append(request_dict)

if len(self.in_memory_records) >= self.batch_size:
await self.flush_in_memory_records()

async def flush_in_memory_records(self):
verbose_logger.debug("flushing in memory records")
@@ -159,4 +157,4 @@ class GalileoObserve(CustomLogger):
)

async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
-verbose_logger.debug(f"On Async Failure")
+verbose_logger.debug("On Async Failure")
@@ -56,8 +56,8 @@ class GCSBucketLogger(GCSBucketBase):
response_obj,
)

-start_time_str = start_time.strftime("%Y-%m-%d %H:%M:%S")
+start_time.strftime("%Y-%m-%d %H:%M:%S")
-end_time_str = end_time.strftime("%Y-%m-%d %H:%M:%S")
+end_time.strftime("%Y-%m-%d %H:%M:%S")
headers = await self.construct_request_headers()

logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
@@ -103,8 +103,8 @@ class GCSBucketLogger(GCSBucketBase):
response_obj,
)

-start_time_str = start_time.strftime("%Y-%m-%d %H:%M:%S")
+start_time.strftime("%Y-%m-%d %H:%M:%S")
-end_time_str = end_time.strftime("%Y-%m-%d %H:%M:%S")
+end_time.strftime("%Y-%m-%d %H:%M:%S")
headers = await self.construct_request_headers()

logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
@@ -1,8 +1,9 @@
-import requests # type: ignore
import json
import traceback
from datetime import datetime, timezone

+import requests # type: ignore


class GreenscaleLogger:
def __init__(self):
@@ -29,7 +30,7 @@ class GreenscaleLogger:
"%Y-%m-%dT%H:%M:%SZ"
)

-if type(end_time) == datetime and type(start_time) == datetime:
+if type(end_time) is datetime and type(start_time) is datetime:
data["invocationLatency"] = int(
(end_time - start_time).total_seconds() * 1000
)
@@ -50,6 +51,9 @@ class GreenscaleLogger:

data["tags"] = tags

+if self.greenscale_logging_url is None:
+raise Exception("Greenscale Logger Error - No logging URL found")

response = requests.post(
self.greenscale_logging_url,
headers=self.headers,
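Note on the `type(...) ==` changes above (also applied in the Athina hunks earlier): ruff E721 prefers `type(x) is SomeType` for exact-type checks. A short sketch of the guarded latency computation this enables, assuming only the standard library:

from datetime import datetime, timedelta
from typing import Optional

def latency_ms(start_time: object, end_time: object) -> Optional[int]:
    # Only compute latency when both values are exactly datetime instances.
    if type(end_time) is datetime and type(start_time) is datetime:
        return int((end_time - start_time).total_seconds() * 1000)
    return None

now = datetime.now()
print(latency_ms(now, now + timedelta(milliseconds=250)))  # ~250
print(latency_ms("not-a-datetime", now))                   # None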
@@ -1,15 +1,28 @@
#### What this does ####
# On success, logs events to Helicone
-import dotenv, os
-import requests # type: ignore
-import litellm
+import os
import traceback

+import dotenv
+import requests # type: ignore

+import litellm
from litellm._logging import verbose_logger


class HeliconeLogger:
# Class variables or attributes
-helicone_model_list = ["gpt", "claude", "command-r", "command-r-plus", "command-light", "command-medium", "command-medium-beta", "command-xlarge-nightly", "command-nightly"]
+helicone_model_list = [
+"gpt",
+"claude",
+"command-r",
+"command-r-plus",
+"command-light",
+"command-medium",
+"command-medium-beta",
+"command-xlarge-nightly",
+"command-nightly",
+]

def __init__(self):
# Instance variables
@@ -17,7 +30,7 @@ class HeliconeLogger:
self.key = os.getenv("HELICONE_API_KEY")

def claude_mapping(self, model, messages, response_obj):
-from anthropic import HUMAN_PROMPT, AI_PROMPT
+from anthropic import AI_PROMPT, HUMAN_PROMPT

prompt = f"{HUMAN_PROMPT}"
for message in messages:
@@ -29,7 +42,6 @@ class HeliconeLogger:
else:
prompt += f"{HUMAN_PROMPT}{message['content']}"
prompt += f"{AI_PROMPT}"
-claude_provider_request = {"model": model, "prompt": prompt}

choice = response_obj["choices"][0]
message = choice["message"]
@@ -37,12 +49,14 @@ class HeliconeLogger:
content = []
if "tool_calls" in message and message["tool_calls"]:
for tool_call in message["tool_calls"]:
-content.append({
-"type": "tool_use",
-"id": tool_call["id"],
-"name": tool_call["function"]["name"],
-"input": tool_call["function"]["arguments"]
-})
+content.append(
+{
+"type": "tool_use",
+"id": tool_call["id"],
+"name": tool_call["function"]["name"],
+"input": tool_call["function"]["arguments"],
+}
+)
elif "content" in message and message["content"]:
content = [{"type": "text", "text": message["content"]}]

@@ -56,8 +70,8 @@ class HeliconeLogger:
"stop_sequence": None,
"usage": {
"input_tokens": response_obj["usage"]["prompt_tokens"],
-"output_tokens": response_obj["usage"]["completion_tokens"]
+"output_tokens": response_obj["usage"]["completion_tokens"],
-}
+},
}

return claude_response_obj
@@ -99,10 +113,8 @@ class HeliconeLogger:
f"Helicone Logging - Enters logging function for model {model}"
)
litellm_params = kwargs.get("litellm_params", {})
-litellm_call_id = kwargs.get("litellm_call_id", None)
+kwargs.get("litellm_call_id", None)
-metadata = (
-litellm_params.get("metadata", {}) or {}
-)
+metadata = litellm_params.get("metadata", {}) or {}
metadata = self.add_metadata_from_header(litellm_params, metadata)
model = (
model
@@ -175,6 +187,6 @@ class HeliconeLogger:
f"Helicone Logging - Error Request was not successful. Status Code: {response.status_code}"
)
print_verbose(f"Helicone Logging - Error {response.text}")
|
print_verbose(f"Helicone Logging - Error {response.text}")
|
||||||
except:
|
except Exception:
|
||||||
print_verbose(f"Helicone Logging Error - {traceback.format_exc()}")
|
print_verbose(f"Helicone Logging Error - {traceback.format_exc()}")
|
||||||
pass
|
pass
|
||||||
|
|
|
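The Helicone diff repeats two cleanups that appear throughout this commit: bare `except:` clauses become `except Exception:` (E722, so `KeyboardInterrupt` and `SystemExit` are no longer swallowed), and lookups whose results were never read, such as `litellm_call_id = kwargs.get(...)`, are reduced to the bare call; handlers likewise drop `as e` bindings they never use. A small sketch of the resulting shape, with made-up names:

import traceback


def log_event(kwargs: dict, print_verbose=print) -> None:
    try:
        # keep only the lookups whose results are actually used
        litellm_params = kwargs.get("litellm_params", {})
        metadata = litellm_params.get("metadata", {}) or {}
        print_verbose(f"logging with metadata keys: {list(metadata)}")
    except Exception:  # E722: no bare `except:`
        print_verbose(f"logging error - {traceback.format_exc()}")


log_event({"litellm_params": {"metadata": {"user": "abc"}}})
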
@ -11,7 +11,7 @@ import dotenv
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
import litellm
|
import litellm
|
||||||
from litellm import verbose_logger
|
from litellm._logging import verbose_logger
|
||||||
from litellm.integrations.custom_logger import CustomLogger
|
from litellm.integrations.custom_logger import CustomLogger
|
||||||
from litellm.llms.custom_httpx.http_handler import (
|
from litellm.llms.custom_httpx.http_handler import (
|
||||||
HTTPHandler,
|
HTTPHandler,
|
||||||
|
@ -65,8 +65,8 @@ class LagoLogger(CustomLogger):
|
||||||
raise Exception("Missing keys={} in environment.".format(missing_keys))
|
raise Exception("Missing keys={} in environment.".format(missing_keys))
|
||||||
|
|
||||||
def _common_logic(self, kwargs: dict, response_obj) -> dict:
|
def _common_logic(self, kwargs: dict, response_obj) -> dict:
|
||||||
call_id = response_obj.get("id", kwargs.get("litellm_call_id"))
|
response_obj.get("id", kwargs.get("litellm_call_id"))
|
||||||
dt = get_utc_datetime().isoformat()
|
get_utc_datetime().isoformat()
|
||||||
cost = kwargs.get("response_cost", None)
|
cost = kwargs.get("response_cost", None)
|
||||||
model = kwargs.get("model")
|
model = kwargs.get("model")
|
||||||
usage = {}
|
usage = {}
|
||||||
|
@ -86,7 +86,7 @@ class LagoLogger(CustomLogger):
|
||||||
end_user_id = proxy_server_request.get("body", {}).get("user", None)
|
end_user_id = proxy_server_request.get("body", {}).get("user", None)
|
||||||
user_id = litellm_params["metadata"].get("user_api_key_user_id", None)
|
user_id = litellm_params["metadata"].get("user_api_key_user_id", None)
|
||||||
team_id = litellm_params["metadata"].get("user_api_key_team_id", None)
|
team_id = litellm_params["metadata"].get("user_api_key_team_id", None)
|
||||||
org_id = litellm_params["metadata"].get("user_api_key_org_id", None)
|
litellm_params["metadata"].get("user_api_key_org_id", None)
|
||||||
|
|
||||||
charge_by: Literal["end_user_id", "team_id", "user_id"] = "end_user_id"
|
charge_by: Literal["end_user_id", "team_id", "user_id"] = "end_user_id"
|
||||||
external_customer_id: Optional[str] = None
|
external_customer_id: Optional[str] = None
|
||||||
|
@ -158,8 +158,9 @@ class LagoLogger(CustomLogger):
|
||||||
|
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if hasattr(response, "text"):
|
error_response = getattr(e, "response", None)
|
||||||
litellm.print_verbose(f"\nError Message: {response.text}")
|
if error_response is not None and hasattr(error_response, "text"):
|
||||||
|
verbose_logger.debug(f"\nError Message: {error_response.text}")
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||||
|
@ -199,5 +200,5 @@ class LagoLogger(CustomLogger):
|
||||||
verbose_logger.debug(f"Logged Lago Object: {response.text}")
|
verbose_logger.debug(f"Logged Lago Object: {response.text}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if response is not None and hasattr(response, "text"):
|
if response is not None and hasattr(response, "text"):
|
||||||
litellm.print_verbose(f"\nError Message: {response.text}")
|
verbose_logger.debug(f"\nError Message: {response.text}")
|
||||||
raise e
|
raise e
|
||||||
|
|
|
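The Lago failure path above no longer reads a `response` variable that may be unbound when the request itself raised; it fetches the response off the exception with `getattr(e, "response", None)`, checks it before touching `.text`, and logs through `verbose_logger.debug` rather than `litellm.print_verbose`. A hedged sketch of that defensive shape (the httpx call and logger are stand-ins, not the litellm objects):

import logging

import httpx

logger = logging.getLogger("example")


def post_usage_event(url: str, payload: dict) -> None:
    try:
        response = httpx.post(url, json=payload, timeout=5)
        response.raise_for_status()
    except Exception as e:
        # HTTPStatusError carries a response; connection errors do not, so fall
        # back to None instead of referencing a possibly-unbound local.
        error_response = getattr(e, "response", None)
        if error_response is not None and hasattr(error_response, "text"):
            logger.debug("error message: %s", error_response.text)
        raise


try:
    post_usage_event("https://httpbin.org/status/500", {"event": "demo"})
except Exception:
    pass
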
@ -67,7 +67,7 @@ class LangFuseLogger:
|
||||||
try:
|
try:
|
||||||
project_id = self.Langfuse.client.projects.get().data[0].id
|
project_id = self.Langfuse.client.projects.get().data[0].id
|
||||||
os.environ["LANGFUSE_PROJECT_ID"] = project_id
|
os.environ["LANGFUSE_PROJECT_ID"] = project_id
|
||||||
except:
|
except Exception:
|
||||||
project_id = None
|
project_id = None
|
||||||
|
|
||||||
if os.getenv("UPSTREAM_LANGFUSE_SECRET_KEY") is not None:
|
if os.getenv("UPSTREAM_LANGFUSE_SECRET_KEY") is not None:
|
||||||
|
@ -184,7 +184,7 @@ class LangFuseLogger:
|
||||||
if not isinstance(value, (str, int, bool, float)):
|
if not isinstance(value, (str, int, bool, float)):
|
||||||
try:
|
try:
|
||||||
optional_params[param] = str(value)
|
optional_params[param] = str(value)
|
||||||
except:
|
except Exception:
|
||||||
# if casting value to str fails don't block logging
|
# if casting value to str fails don't block logging
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -275,7 +275,7 @@ class LangFuseLogger:
|
||||||
print_verbose(
|
print_verbose(
|
||||||
f"Langfuse Layer Logging - final response object: {response_obj}"
|
f"Langfuse Layer Logging - final response object: {response_obj}"
|
||||||
)
|
)
|
||||||
verbose_logger.info(f"Langfuse Layer Logging - logging success")
|
verbose_logger.info("Langfuse Layer Logging - logging success")
|
||||||
|
|
||||||
return {"trace_id": trace_id, "generation_id": generation_id}
|
return {"trace_id": trace_id, "generation_id": generation_id}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -492,7 +492,7 @@ class LangFuseLogger:
|
||||||
output if not mask_output else "redacted-by-litellm"
|
output if not mask_output else "redacted-by-litellm"
|
||||||
)
|
)
|
||||||
|
|
||||||
if debug == True or (isinstance(debug, str) and debug.lower() == "true"):
|
if debug is True or (isinstance(debug, str) and debug.lower() == "true"):
|
||||||
if "metadata" in trace_params:
|
if "metadata" in trace_params:
|
||||||
# log the raw_metadata in the trace
|
# log the raw_metadata in the trace
|
||||||
trace_params["metadata"]["metadata_passed_to_litellm"] = metadata
|
trace_params["metadata"]["metadata_passed_to_litellm"] = metadata
|
||||||
|
@ -535,8 +535,8 @@ class LangFuseLogger:
|
||||||
|
|
||||||
proxy_server_request = litellm_params.get("proxy_server_request", None)
|
proxy_server_request = litellm_params.get("proxy_server_request", None)
|
||||||
if proxy_server_request:
|
if proxy_server_request:
|
||||||
method = proxy_server_request.get("method", None)
|
proxy_server_request.get("method", None)
|
||||||
url = proxy_server_request.get("url", None)
|
proxy_server_request.get("url", None)
|
||||||
headers = proxy_server_request.get("headers", None)
|
headers = proxy_server_request.get("headers", None)
|
||||||
clean_headers = {}
|
clean_headers = {}
|
||||||
if headers:
|
if headers:
|
||||||
|
@ -625,7 +625,7 @@ class LangFuseLogger:
|
||||||
generation_client = trace.generation(**generation_params)
|
generation_client = trace.generation(**generation_params)
|
||||||
|
|
||||||
return generation_client.trace_id, generation_id
|
return generation_client.trace_id, generation_id
|
||||||
except Exception as e:
|
except Exception:
|
||||||
verbose_logger.error(f"Langfuse Layer Error - {traceback.format_exc()}")
|
verbose_logger.error(f"Langfuse Layer Error - {traceback.format_exc()}")
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
|
|
|
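Two more ruff rules drive the Langfuse edits: f-strings without placeholders lose the `f` prefix (F541), and comparisons written as `debug == True` become `debug is True` (E712), while the string fallback for header-provided flags is kept. A short sketch of both, with invented names:

import logging

logger = logging.getLogger("example")


def maybe_log(debug, message: str) -> None:
    # E712: compare to True with `is` (the value may also arrive as the string "true")
    if debug is True or (isinstance(debug, str) and debug.lower() == "true"):
        logger.info("debug logging enabled")  # F541: no placeholders, so no f-prefix
        logger.info("message: %s", message)


maybe_log(True, "hello")
maybe_log("TRUE", "hello again")
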
@ -404,7 +404,7 @@ class LangsmithLogger(CustomBatchLogger):
|
||||||
verbose_logger.exception(
|
verbose_logger.exception(
|
||||||
f"Langsmith HTTP Error: {e.response.status_code} - {e.response.text}"
|
f"Langsmith HTTP Error: {e.response.status_code} - {e.response.text}"
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception:
|
||||||
verbose_logger.exception(
|
verbose_logger.exception(
|
||||||
f"Langsmith Layer Error - {traceback.format_exc()}"
|
f"Langsmith Layer Error - {traceback.format_exc()}"
|
||||||
)
|
)
|
||||||
|
|
|
@ -1,6 +1,10 @@
|
||||||
import requests, traceback, json, os
|
import json
|
||||||
|
import os
|
||||||
|
import traceback
|
||||||
import types
|
import types
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
|
||||||
class LiteDebugger:
|
class LiteDebugger:
|
||||||
user_email = None
|
user_email = None
|
||||||
|
@ -17,23 +21,17 @@ class LiteDebugger:
|
||||||
email or os.getenv("LITELLM_TOKEN") or os.getenv("LITELLM_EMAIL")
|
email or os.getenv("LITELLM_TOKEN") or os.getenv("LITELLM_EMAIL")
|
||||||
)
|
)
|
||||||
if (
|
if (
|
||||||
self.user_email == None
|
self.user_email is None
|
||||||
): # if users are trying to use_client=True but token not set
|
): # if users are trying to use_client=True but token not set
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"litellm.use_client = True but no token or email passed. Please set it in litellm.token"
|
"litellm.use_client = True but no token or email passed. Please set it in litellm.token"
|
||||||
)
|
)
|
||||||
self.dashboard_url = "https://admin.litellm.ai/" + self.user_email
|
self.dashboard_url = "https://admin.litellm.ai/" + self.user_email
|
||||||
try:
|
if self.user_email is None:
|
||||||
print(
|
|
||||||
f"\033[92mHere's your LiteLLM Dashboard 👉 \033[94m\033[4m{self.dashboard_url}\033[0m"
|
|
||||||
)
|
|
||||||
except:
|
|
||||||
print(f"Here's your LiteLLM Dashboard 👉 {self.dashboard_url}")
|
|
||||||
if self.user_email == None:
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"[Non-Blocking Error] LiteLLMDebugger: Missing LITELLM_TOKEN. Set it in your environment. Eg.: os.environ['LITELLM_TOKEN']= <your_email>"
|
"[Non-Blocking Error] LiteLLMDebugger: Missing LITELLM_TOKEN. Set it in your environment. Eg.: os.environ['LITELLM_TOKEN']= <your_email>"
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"[Non-Blocking Error] LiteLLMDebugger: Missing LITELLM_TOKEN. Set it in your environment. Eg.: os.environ['LITELLM_TOKEN']= <your_email>"
|
"[Non-Blocking Error] LiteLLMDebugger: Missing LITELLM_TOKEN. Set it in your environment. Eg.: os.environ['LITELLM_TOKEN']= <your_email>"
|
||||||
)
|
)
|
||||||
|
@ -49,123 +47,18 @@ class LiteDebugger:
|
||||||
litellm_params,
|
litellm_params,
|
||||||
optional_params,
|
optional_params,
|
||||||
):
|
):
|
||||||
print_verbose(
|
"""
|
||||||
f"LiteDebugger: Pre-API Call Logging for call id {litellm_call_id}"
|
This integration is not implemented yet.
|
||||||
)
|
"""
|
||||||
try:
|
return
|
||||||
print_verbose(
|
|
||||||
f"LiteLLMDebugger: Logging - Enters input logging function for model {model}"
|
|
||||||
)
|
|
||||||
|
|
||||||
def remove_key_value(dictionary, key):
|
|
||||||
new_dict = dictionary.copy() # Create a copy of the original dictionary
|
|
||||||
new_dict.pop(key) # Remove the specified key-value pair from the copy
|
|
||||||
return new_dict
|
|
||||||
|
|
||||||
updated_litellm_params = remove_key_value(litellm_params, "logger_fn")
|
|
||||||
|
|
||||||
if call_type == "embedding":
|
|
||||||
for (
|
|
||||||
message
|
|
||||||
) in (
|
|
||||||
messages
|
|
||||||
): # assuming the input is a list as required by the embedding function
|
|
||||||
litellm_data_obj = {
|
|
||||||
"model": model,
|
|
||||||
"messages": [{"role": "user", "content": message}],
|
|
||||||
"end_user": end_user,
|
|
||||||
"status": "initiated",
|
|
||||||
"litellm_call_id": litellm_call_id,
|
|
||||||
"user_email": self.user_email,
|
|
||||||
"litellm_params": updated_litellm_params,
|
|
||||||
"optional_params": optional_params,
|
|
||||||
}
|
|
||||||
print_verbose(
|
|
||||||
f"LiteLLMDebugger: Logging - logged data obj {litellm_data_obj}"
|
|
||||||
)
|
|
||||||
response = requests.post(
|
|
||||||
url=self.api_url,
|
|
||||||
headers={"content-type": "application/json"},
|
|
||||||
data=json.dumps(litellm_data_obj),
|
|
||||||
)
|
|
||||||
print_verbose(f"LiteDebugger: embedding api response - {response.text}")
|
|
||||||
elif call_type == "completion":
|
|
||||||
litellm_data_obj = {
|
|
||||||
"model": model,
|
|
||||||
"messages": messages
|
|
||||||
if isinstance(messages, list)
|
|
||||||
else [{"role": "user", "content": messages}],
|
|
||||||
"end_user": end_user,
|
|
||||||
"status": "initiated",
|
|
||||||
"litellm_call_id": litellm_call_id,
|
|
||||||
"user_email": self.user_email,
|
|
||||||
"litellm_params": updated_litellm_params,
|
|
||||||
"optional_params": optional_params,
|
|
||||||
}
|
|
||||||
print_verbose(
|
|
||||||
f"LiteLLMDebugger: Logging - logged data obj {litellm_data_obj}"
|
|
||||||
)
|
|
||||||
response = requests.post(
|
|
||||||
url=self.api_url,
|
|
||||||
headers={"content-type": "application/json"},
|
|
||||||
data=json.dumps(litellm_data_obj),
|
|
||||||
)
|
|
||||||
print_verbose(
|
|
||||||
f"LiteDebugger: completion api response - {response.text}"
|
|
||||||
)
|
|
||||||
except:
|
|
||||||
print_verbose(
|
|
||||||
f"[Non-Blocking Error] LiteDebugger: Logging Error - {traceback.format_exc()}"
|
|
||||||
)
|
|
||||||
pass
|
|
||||||
|
|
||||||
def post_call_log_event(
|
def post_call_log_event(
|
||||||
self, original_response, litellm_call_id, print_verbose, call_type, stream
|
self, original_response, litellm_call_id, print_verbose, call_type, stream
|
||||||
):
|
):
|
||||||
print_verbose(
|
"""
|
||||||
f"LiteDebugger: Post-API Call Logging for call id {litellm_call_id}"
|
This integration is not implemented yet.
|
||||||
)
|
"""
|
||||||
try:
|
return
|
||||||
if call_type == "embedding":
|
|
||||||
litellm_data_obj = {
|
|
||||||
"status": "received",
|
|
||||||
"additional_details": {
|
|
||||||
"original_response": str(
|
|
||||||
original_response["data"][0]["embedding"][:5]
|
|
||||||
)
|
|
||||||
}, # don't store the entire vector
|
|
||||||
"litellm_call_id": litellm_call_id,
|
|
||||||
"user_email": self.user_email,
|
|
||||||
}
|
|
||||||
elif call_type == "completion" and not stream:
|
|
||||||
litellm_data_obj = {
|
|
||||||
"status": "received",
|
|
||||||
"additional_details": {"original_response": original_response},
|
|
||||||
"litellm_call_id": litellm_call_id,
|
|
||||||
"user_email": self.user_email,
|
|
||||||
}
|
|
||||||
elif call_type == "completion" and stream:
|
|
||||||
litellm_data_obj = {
|
|
||||||
"status": "received",
|
|
||||||
"additional_details": {
|
|
||||||
"original_response": "Streamed response"
|
|
||||||
if isinstance(original_response, types.GeneratorType)
|
|
||||||
else original_response
|
|
||||||
},
|
|
||||||
"litellm_call_id": litellm_call_id,
|
|
||||||
"user_email": self.user_email,
|
|
||||||
}
|
|
||||||
print_verbose(f"litedebugger post-call data object - {litellm_data_obj}")
|
|
||||||
response = requests.post(
|
|
||||||
url=self.api_url,
|
|
||||||
headers={"content-type": "application/json"},
|
|
||||||
data=json.dumps(litellm_data_obj),
|
|
||||||
)
|
|
||||||
print_verbose(f"LiteDebugger: api response - {response.text}")
|
|
||||||
except:
|
|
||||||
print_verbose(
|
|
||||||
f"[Non-Blocking Error] LiteDebugger: Logging Error - {traceback.format_exc()}"
|
|
||||||
)
|
|
||||||
|
|
||||||
def log_event(
|
def log_event(
|
||||||
self,
|
self,
|
||||||
|
@ -178,85 +71,7 @@ class LiteDebugger:
|
||||||
call_type,
|
call_type,
|
||||||
stream=False,
|
stream=False,
|
||||||
):
|
):
|
||||||
print_verbose(
|
"""
|
||||||
f"LiteDebugger: Success/Failure Call Logging for call id {litellm_call_id}"
|
This integration is not implemented yet.
|
||||||
)
|
"""
|
||||||
try:
|
return
|
||||||
print_verbose(
|
|
||||||
f"LiteLLMDebugger: Success/Failure Logging - Enters handler logging function for function {call_type} and stream set to {stream} with response object {response_obj}"
|
|
||||||
)
|
|
||||||
total_cost = 0 # [TODO] implement cost tracking
|
|
||||||
response_time = (end_time - start_time).total_seconds()
|
|
||||||
if call_type == "completion" and stream == False:
|
|
||||||
litellm_data_obj = {
|
|
||||||
"response_time": response_time,
|
|
||||||
"total_cost": total_cost,
|
|
||||||
"response": response_obj["choices"][0]["message"]["content"],
|
|
||||||
"litellm_call_id": litellm_call_id,
|
|
||||||
"status": "success",
|
|
||||||
}
|
|
||||||
print_verbose(
|
|
||||||
f"LiteDebugger: Logging - final data object: {litellm_data_obj}"
|
|
||||||
)
|
|
||||||
response = requests.post(
|
|
||||||
url=self.api_url,
|
|
||||||
headers={"content-type": "application/json"},
|
|
||||||
data=json.dumps(litellm_data_obj),
|
|
||||||
)
|
|
||||||
elif call_type == "embedding":
|
|
||||||
litellm_data_obj = {
|
|
||||||
"response_time": response_time,
|
|
||||||
"total_cost": total_cost,
|
|
||||||
"response": str(response_obj["data"][0]["embedding"][:5]),
|
|
||||||
"litellm_call_id": litellm_call_id,
|
|
||||||
"status": "success",
|
|
||||||
}
|
|
||||||
response = requests.post(
|
|
||||||
url=self.api_url,
|
|
||||||
headers={"content-type": "application/json"},
|
|
||||||
data=json.dumps(litellm_data_obj),
|
|
||||||
)
|
|
||||||
elif call_type == "completion" and stream == True:
|
|
||||||
if len(response_obj["content"]) > 0: # don't log the empty strings
|
|
||||||
litellm_data_obj = {
|
|
||||||
"response_time": response_time,
|
|
||||||
"total_cost": total_cost,
|
|
||||||
"response": response_obj["content"],
|
|
||||||
"litellm_call_id": litellm_call_id,
|
|
||||||
"status": "success",
|
|
||||||
}
|
|
||||||
print_verbose(
|
|
||||||
f"LiteDebugger: Logging - final data object: {litellm_data_obj}"
|
|
||||||
)
|
|
||||||
response = requests.post(
|
|
||||||
url=self.api_url,
|
|
||||||
headers={"content-type": "application/json"},
|
|
||||||
data=json.dumps(litellm_data_obj),
|
|
||||||
)
|
|
||||||
elif "error" in response_obj:
|
|
||||||
if "Unable to map your input to a model." in response_obj["error"]:
|
|
||||||
total_cost = 0
|
|
||||||
litellm_data_obj = {
|
|
||||||
"response_time": response_time,
|
|
||||||
"model": response_obj["model"],
|
|
||||||
"total_cost": total_cost,
|
|
||||||
"error": response_obj["error"],
|
|
||||||
"end_user": end_user,
|
|
||||||
"litellm_call_id": litellm_call_id,
|
|
||||||
"status": "failure",
|
|
||||||
"user_email": self.user_email,
|
|
||||||
}
|
|
||||||
print_verbose(
|
|
||||||
f"LiteDebugger: Logging - final data object: {litellm_data_obj}"
|
|
||||||
)
|
|
||||||
response = requests.post(
|
|
||||||
url=self.api_url,
|
|
||||||
headers={"content-type": "application/json"},
|
|
||||||
data=json.dumps(litellm_data_obj),
|
|
||||||
)
|
|
||||||
print_verbose(f"LiteDebugger: api response - {response.text}")
|
|
||||||
except:
|
|
||||||
print_verbose(
|
|
||||||
f"[Non-Blocking Error] LiteDebugger: Logging Error - {traceback.format_exc()}"
|
|
||||||
)
|
|
||||||
pass
|
|
||||||
|
|
|
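The LiteDebugger file is the one place where whole method bodies go away rather than getting linted: `input_log_event`, `post_call_log_event`, and `log_event` are each reduced to a docstring noting the integration is not implemented plus an immediate `return`, which removes the bare excepts and the now-dead `requests.post` calls inside them. A sketch of that stub shape under an illustrative class name:

class ExampleDebugger:
    """Placeholder for an integration that is intentionally disabled."""

    def input_log_event(self, *args, **kwargs):
        """
        This integration is not implemented yet.
        """
        return

    def log_event(self, *args, **kwargs):
        """
        This integration is not implemented yet.
        """
        return


ExampleDebugger().log_event(model="gpt-3.5-turbo")  # no-op by design
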
@ -27,7 +27,7 @@ class LogfireLogger:
|
||||||
# Class variables or attributes
|
# Class variables or attributes
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
try:
|
try:
|
||||||
verbose_logger.debug(f"in init logfire logger")
|
verbose_logger.debug("in init logfire logger")
|
||||||
import logfire
|
import logfire
|
||||||
|
|
||||||
# only setting up logfire if we are sending to logfire
|
# only setting up logfire if we are sending to logfire
|
||||||
|
@ -116,7 +116,7 @@ class LogfireLogger:
|
||||||
id = response_obj.get("id", str(uuid.uuid4()))
|
id = response_obj.get("id", str(uuid.uuid4()))
|
||||||
try:
|
try:
|
||||||
response_time = (end_time - start_time).total_seconds()
|
response_time = (end_time - start_time).total_seconds()
|
||||||
except:
|
except Exception:
|
||||||
response_time = None
|
response_time = None
|
||||||
|
|
||||||
# Clean Metadata before logging - never log raw metadata
|
# Clean Metadata before logging - never log raw metadata
|
||||||
|
|
|
@ -1,8 +1,8 @@
|
||||||
#### What this does ####
|
#### What this does ####
|
||||||
# On success + failure, log events to lunary.ai
|
# On success + failure, log events to lunary.ai
|
||||||
from datetime import datetime, timezone
|
|
||||||
import traceback
|
|
||||||
import importlib
|
import importlib
|
||||||
|
import traceback
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
import packaging
|
import packaging
|
||||||
|
|
||||||
|
@ -74,9 +74,9 @@ class LunaryLogger:
|
||||||
try:
|
try:
|
||||||
import lunary
|
import lunary
|
||||||
|
|
||||||
version = importlib.metadata.version("lunary")
|
version = importlib.metadata.version("lunary") # type: ignore
|
||||||
# if version < 0.1.43 then raise ImportError
|
# if version < 0.1.43 then raise ImportError
|
||||||
if packaging.version.Version(version) < packaging.version.Version("0.1.43"):
|
if packaging.version.Version(version) < packaging.version.Version("0.1.43"): # type: ignore
|
||||||
print( # noqa
|
print( # noqa
|
||||||
"Lunary version outdated. Required: >= 0.1.43. Upgrade via 'pip install lunary --upgrade'"
|
"Lunary version outdated. Required: >= 0.1.43. Upgrade via 'pip install lunary --upgrade'"
|
||||||
)
|
)
|
||||||
|
@ -97,7 +97,7 @@ class LunaryLogger:
|
||||||
run_id,
|
run_id,
|
||||||
model,
|
model,
|
||||||
print_verbose,
|
print_verbose,
|
||||||
extra=None,
|
extra={},
|
||||||
input=None,
|
input=None,
|
||||||
user_id=None,
|
user_id=None,
|
||||||
response_obj=None,
|
response_obj=None,
|
||||||
|
@ -128,7 +128,7 @@ class LunaryLogger:
|
||||||
if not isinstance(value, (str, int, bool, float)) and param != "tools":
|
if not isinstance(value, (str, int, bool, float)) and param != "tools":
|
||||||
try:
|
try:
|
||||||
extra[param] = str(value)
|
extra[param] = str(value)
|
||||||
except:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
if response_obj:
|
if response_obj:
|
||||||
|
@ -175,6 +175,6 @@ class LunaryLogger:
|
||||||
token_usage=usage,
|
token_usage=usage,
|
||||||
)
|
)
|
||||||
|
|
||||||
except:
|
except Exception:
|
||||||
print_verbose(f"Lunary Logging Error - {traceback.format_exc()}")
|
print_verbose(f"Lunary Logging Error - {traceback.format_exc()}")
|
||||||
pass
|
pass
|
||||||
|
|
|
@ -98,7 +98,7 @@ class OpenTelemetry(CustomLogger):
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
logging.basicConfig(level=logging.DEBUG)
|
logging.basicConfig(level=logging.DEBUG)
|
||||||
logger = logging.getLogger(__name__)
|
logging.getLogger(__name__)
|
||||||
|
|
||||||
# Enable OpenTelemetry logging
|
# Enable OpenTelemetry logging
|
||||||
otel_exporter_logger = logging.getLogger("opentelemetry.sdk.trace.export")
|
otel_exporter_logger = logging.getLogger("opentelemetry.sdk.trace.export")
|
||||||
|
@ -520,7 +520,7 @@ class OpenTelemetry(CustomLogger):
|
||||||
def set_raw_request_attributes(self, span: Span, kwargs, response_obj):
|
def set_raw_request_attributes(self, span: Span, kwargs, response_obj):
|
||||||
from litellm.proxy._types import SpanAttributes
|
from litellm.proxy._types import SpanAttributes
|
||||||
|
|
||||||
optional_params = kwargs.get("optional_params", {})
|
kwargs.get("optional_params", {})
|
||||||
litellm_params = kwargs.get("litellm_params", {}) or {}
|
litellm_params = kwargs.get("litellm_params", {}) or {}
|
||||||
custom_llm_provider = litellm_params.get("custom_llm_provider", "Unknown")
|
custom_llm_provider = litellm_params.get("custom_llm_provider", "Unknown")
|
||||||
|
|
||||||
|
@ -769,6 +769,6 @@ class OpenTelemetry(CustomLogger):
|
||||||
management_endpoint_span.set_attribute(f"request.{key}", value)
|
management_endpoint_span.set_attribute(f"request.{key}", value)
|
||||||
|
|
||||||
_exception = logging_payload.exception
|
_exception = logging_payload.exception
|
||||||
management_endpoint_span.set_attribute(f"exception", str(_exception))
|
management_endpoint_span.set_attribute("exception", str(_exception))
|
||||||
management_endpoint_span.set_status(Status(StatusCode.ERROR))
|
management_endpoint_span.set_status(Status(StatusCode.ERROR))
|
||||||
management_endpoint_span.end(end_time=_end_time_ns)
|
management_endpoint_span.end(end_time=_end_time_ns)
|
||||||
|
|
|
@ -520,7 +520,7 @@ class PrometheusLogger(CustomLogger):
|
||||||
user_api_team_alias = standard_logging_payload["metadata"][
|
user_api_team_alias = standard_logging_payload["metadata"][
|
||||||
"user_api_key_team_alias"
|
"user_api_key_team_alias"
|
||||||
]
|
]
|
||||||
exception = kwargs.get("exception", None)
|
kwargs.get("exception", None)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.litellm_llm_api_failed_requests_metric.labels(
|
self.litellm_llm_api_failed_requests_metric.labels(
|
||||||
|
@ -679,7 +679,7 @@ class PrometheusLogger(CustomLogger):
|
||||||
).inc()
|
).inc()
|
||||||
|
|
||||||
pass
|
pass
|
||||||
except:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def set_llm_deployment_success_metrics(
|
def set_llm_deployment_success_metrics(
|
||||||
|
@ -800,7 +800,7 @@ class PrometheusLogger(CustomLogger):
|
||||||
|
|
||||||
if (
|
if (
|
||||||
request_kwargs.get("stream", None) is not None
|
request_kwargs.get("stream", None) is not None
|
||||||
and request_kwargs["stream"] == True
|
and request_kwargs["stream"] is True
|
||||||
):
|
):
|
||||||
# only log ttft for streaming request
|
# only log ttft for streaming request
|
||||||
time_to_first_token_response_time = (
|
time_to_first_token_response_time = (
|
||||||
|
|
|
@ -3,11 +3,17 @@
|
||||||
# On success + failure, log events to Prometheus for litellm / adjacent services (litellm, redis, postgres, llm api providers)
|
# On success + failure, log events to Prometheus for litellm / adjacent services (litellm, redis, postgres, llm api providers)
|
||||||
|
|
||||||
|
|
||||||
import dotenv, os
|
import datetime
|
||||||
import requests # type: ignore
|
import os
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
import traceback
|
import traceback
|
||||||
import datetime, subprocess, sys
|
import uuid
|
||||||
import litellm, uuid
|
|
||||||
|
import dotenv
|
||||||
|
import requests # type: ignore
|
||||||
|
|
||||||
|
import litellm
|
||||||
from litellm._logging import print_verbose, verbose_logger
|
from litellm._logging import print_verbose, verbose_logger
|
||||||
from litellm.types.services import ServiceLoggerPayload, ServiceTypes
|
from litellm.types.services import ServiceLoggerPayload, ServiceTypes
|
||||||
|
|
||||||
|
@ -23,7 +29,7 @@ class PrometheusServicesLogger:
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
try:
|
try:
|
||||||
from prometheus_client import Counter, Histogram, REGISTRY
|
from prometheus_client import REGISTRY, Counter, Histogram
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
"Missing prometheus_client. Run `pip install prometheus-client`"
|
"Missing prometheus_client. Run `pip install prometheus-client`"
|
||||||
|
@ -33,7 +39,7 @@ class PrometheusServicesLogger:
|
||||||
self.Counter = Counter
|
self.Counter = Counter
|
||||||
self.REGISTRY = REGISTRY
|
self.REGISTRY = REGISTRY
|
||||||
|
|
||||||
verbose_logger.debug(f"in init prometheus services metrics")
|
verbose_logger.debug("in init prometheus services metrics")
|
||||||
|
|
||||||
self.services = [item.value for item in ServiceTypes]
|
self.services = [item.value for item in ServiceTypes]
|
||||||
|
|
||||||
|
|
|
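The Prometheus services hunk shows the import cleanup applied across the integrations: combined statements such as `import dotenv, os` and `import datetime, subprocess, sys` are split one import per line and regrouped with the standard library first, then third-party packages, then litellm itself. A sketch of the resulting layout (the function is invented so that every import is used):

import datetime
import os

import requests  # type: ignore

import litellm


def build_payload() -> dict:
    # stdlib -> third-party -> first-party, one import per line
    return {
        "cwd": os.getcwd(),
        "ts": datetime.datetime.now(datetime.timezone.utc).isoformat(),
        "litellm_version": getattr(litellm, "__version__", "unknown"),
        "http_client": requests.__name__,
    }


print(build_payload())
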
@ -1,9 +1,11 @@
|
||||||
#### What this does ####
|
#### What this does ####
|
||||||
# On success, logs events to Promptlayer
|
# On success, logs events to Promptlayer
|
||||||
import dotenv, os
|
import os
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
import dotenv
|
||||||
import requests # type: ignore
|
import requests # type: ignore
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
import traceback
|
|
||||||
|
|
||||||
|
|
||||||
class PromptLayerLogger:
|
class PromptLayerLogger:
|
||||||
|
@ -84,6 +86,6 @@ class PromptLayerLogger:
|
||||||
f"Prompt Layer Logging: success - metadata post response object: {response.text}"
|
f"Prompt Layer Logging: success - metadata post response object: {response.text}"
|
||||||
)
|
)
|
||||||
|
|
||||||
except:
|
except Exception:
|
||||||
print_verbose(f"error: Prompt Layer Error - {traceback.format_exc()}")
|
print_verbose(f"error: Prompt Layer Error - {traceback.format_exc()}")
|
||||||
pass
|
pass
|
||||||
|
|
|
@ -1,10 +1,15 @@
|
||||||
#### What this does ####
|
#### What this does ####
|
||||||
# On success + failure, log events to Supabase
|
# On success + failure, log events to Supabase
|
||||||
|
|
||||||
import dotenv, os
|
import datetime
|
||||||
import requests # type: ignore
|
import os
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
import traceback
|
import traceback
|
||||||
import datetime, subprocess, sys
|
|
||||||
|
import dotenv
|
||||||
|
import requests # type: ignore
|
||||||
|
|
||||||
import litellm
|
import litellm
|
||||||
|
|
||||||
|
|
||||||
|
@ -21,7 +26,12 @@ class Supabase:
|
||||||
except ImportError:
|
except ImportError:
|
||||||
subprocess.check_call([sys.executable, "-m", "pip", "install", "supabase"])
|
subprocess.check_call([sys.executable, "-m", "pip", "install", "supabase"])
|
||||||
import supabase
|
import supabase
|
||||||
self.supabase_client = supabase.create_client(
|
|
||||||
|
if self.supabase_url is None or self.supabase_key is None:
|
||||||
|
raise ValueError(
|
||||||
|
"LiteLLM Error, trying to use Supabase but url or key not passed. Create a table and set `litellm.supabase_url=<your-url>` and `litellm.supabase_key=<your-key>`"
|
||||||
|
)
|
||||||
|
self.supabase_client = supabase.create_client( # type: ignore
|
||||||
self.supabase_url, self.supabase_key
|
self.supabase_url, self.supabase_key
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -45,7 +55,7 @@ class Supabase:
|
||||||
.execute()
|
.execute()
|
||||||
)
|
)
|
||||||
print_verbose(f"data: {data}")
|
print_verbose(f"data: {data}")
|
||||||
except:
|
except Exception:
|
||||||
print_verbose(f"Supabase Logging Error - {traceback.format_exc()}")
|
print_verbose(f"Supabase Logging Error - {traceback.format_exc()}")
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -109,6 +119,6 @@ class Supabase:
|
||||||
.execute()
|
.execute()
|
||||||
)
|
)
|
||||||
|
|
||||||
except:
|
except Exception:
|
||||||
print_verbose(f"Supabase Logging Error - {traceback.format_exc()}")
|
print_verbose(f"Supabase Logging Error - {traceback.format_exc()}")
|
||||||
pass
|
pass
|
||||||
|
|
|
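The Supabase change above is as much a pyright fix as a ruff one: `supabase.create_client` is only called after confirming neither the URL nor the key is `None`, with a descriptive `ValueError` otherwise, so an `Optional[str]` never flows into a parameter that expects `str`. A minimal sketch of the guard (the returned dict stands in for the real client factory):

import os
from typing import Optional


def make_client(url: Optional[str], key: Optional[str]) -> dict:
    # Narrow Optional[str] to str before calling a factory that needs real strings.
    if url is None or key is None:
        raise ValueError(
            "Supabase url or key not set. Set litellm.supabase_url and litellm.supabase_key."
        )
    return {"url": url, "key": key}  # stand-in for supabase.create_client(url, key)


client = make_client(
    os.getenv("SUPABASE_URL", "https://example.supabase.co"),
    os.getenv("SUPABASE_KEY", "service-role-key"),
)
print(client["url"])
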
@ -167,18 +167,17 @@ try:
|
||||||
trace = self.results_to_trace_tree(request, response, results, time_elapsed)
|
trace = self.results_to_trace_tree(request, response, results, time_elapsed)
|
||||||
return trace
|
return trace
|
||||||
|
|
||||||
except:
|
except Exception:
|
||||||
imported_openAIResponse = False
|
imported_openAIResponse = False
|
||||||
|
|
||||||
|
|
||||||
#### What this does ####
|
#### What this does ####
|
||||||
# On success, logs events to Langfuse
|
# On success, logs events to Langfuse
|
||||||
import os
|
import os
|
||||||
import requests
|
import traceback
|
||||||
import requests
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
import traceback
|
import requests
|
||||||
|
|
||||||
|
|
||||||
class WeightsBiasesLogger:
|
class WeightsBiasesLogger:
|
||||||
|
@ -186,11 +185,11 @@ class WeightsBiasesLogger:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
try:
|
try:
|
||||||
import wandb
|
import wandb
|
||||||
except:
|
except Exception:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
"\033[91m wandb not installed, try running 'pip install wandb' to fix this error\033[0m"
|
"\033[91m wandb not installed, try running 'pip install wandb' to fix this error\033[0m"
|
||||||
)
|
)
|
||||||
if imported_openAIResponse == False:
|
if imported_openAIResponse is False:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
"\033[91m wandb not installed, try running 'pip install wandb' to fix this error\033[0m"
|
"\033[91m wandb not installed, try running 'pip install wandb' to fix this error\033[0m"
|
||||||
)
|
)
|
||||||
|
@ -209,13 +208,14 @@ class WeightsBiasesLogger:
|
||||||
kwargs, response_obj, (end_time - start_time).total_seconds()
|
kwargs, response_obj, (end_time - start_time).total_seconds()
|
||||||
)
|
)
|
||||||
|
|
||||||
if trace is not None:
|
if trace is not None and run is not None:
|
||||||
run.log({"trace": trace})
|
run.log({"trace": trace})
|
||||||
|
|
||||||
run.finish()
|
if run is not None:
|
||||||
print_verbose(
|
run.finish()
|
||||||
f"W&B Logging Logging - final response object: {response_obj}"
|
print_verbose(
|
||||||
)
|
f"W&B Logging Logging - final response object: {response_obj}"
|
||||||
except:
|
)
|
||||||
|
except Exception:
|
||||||
print_verbose(f"W&B Logging Layer Error - {traceback.format_exc()}")
|
print_verbose(f"W&B Logging Layer Error - {traceback.format_exc()}")
|
||||||
pass
|
pass
|
||||||
|
|
|
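The Weights & Biases hunk keeps the optional-import pattern but tightens it: the big `try:` that defines the OpenAI trace helpers now falls back with `except Exception:` and records `imported_openAIResponse = False`, `__init__` tests that flag with `is False`, and `run.log(...)` / `run.finish()` only run when `wandb.init()` actually returned a run. A sketch of the flag idea under illustrative names:

try:
    import wandb  # optional dependency

    wandb_available = True
except Exception:  # ImportError, or anything raised while importing
    wandb_available = False


def log_trace(trace: dict) -> None:
    if wandb_available is False:
        raise RuntimeError("wandb not installed, run 'pip install wandb' to fix this")
    run = wandb.init(project="example")
    if trace is not None and run is not None:
        run.log({"trace": trace})
    if run is not None:
        run.finish()


if not wandb_available:
    print("wandb missing - trace logging disabled")
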
@ -62,7 +62,7 @@ def get_error_message(error_obj) -> Optional[str]:
|
||||||
|
|
||||||
# If all else fails, return None
|
# If all else fails, return None
|
||||||
return None
|
return None
|
||||||
except Exception as e:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
@ -910,7 +910,7 @@ def exception_type( # type: ignore
|
||||||
):
|
):
|
||||||
exception_mapping_worked = True
|
exception_mapping_worked = True
|
||||||
raise BadRequestError(
|
raise BadRequestError(
|
||||||
message=f"SagemakerException - the value of 'n' must be > 0 and <= 2 for sagemaker endpoints",
|
message="SagemakerException - the value of 'n' must be > 0 and <= 2 for sagemaker endpoints",
|
||||||
model=model,
|
model=model,
|
||||||
llm_provider="sagemaker",
|
llm_provider="sagemaker",
|
||||||
response=original_exception.response,
|
response=original_exception.response,
|
||||||
|
@ -1122,7 +1122,7 @@ def exception_type( # type: ignore
|
||||||
# 503 Getting metadata from plugin failed with error: Reauthentication is needed. Please run `gcloud auth application-default login` to reauthenticate.
|
# 503 Getting metadata from plugin failed with error: Reauthentication is needed. Please run `gcloud auth application-default login` to reauthenticate.
|
||||||
exception_mapping_worked = True
|
exception_mapping_worked = True
|
||||||
raise BadRequestError(
|
raise BadRequestError(
|
||||||
message=f"GeminiException - Invalid api key",
|
message="GeminiException - Invalid api key",
|
||||||
model=model,
|
model=model,
|
||||||
llm_provider="palm",
|
llm_provider="palm",
|
||||||
response=original_exception.response,
|
response=original_exception.response,
|
||||||
|
@ -2067,12 +2067,34 @@ def exception_logging(
|
||||||
logger_fn(
|
logger_fn(
|
||||||
model_call_details
|
model_call_details
|
||||||
) # Expectation: any logger function passed in by the user should accept a dict object
|
) # Expectation: any logger function passed in by the user should accept a dict object
|
||||||
except Exception as e:
|
except Exception:
|
||||||
verbose_logger.debug(
|
verbose_logger.debug(
|
||||||
f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while logging {traceback.format_exc()}"
|
f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while logging {traceback.format_exc()}"
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception:
|
||||||
verbose_logger.debug(
|
verbose_logger.debug(
|
||||||
f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while logging {traceback.format_exc()}"
|
f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while logging {traceback.format_exc()}"
|
||||||
)
|
)
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def _add_key_name_and_team_to_alert(request_info: str, metadata: dict) -> str:
|
||||||
|
"""
|
||||||
|
Internal helper function for litellm proxy
|
||||||
|
Add the Key Name + Team Name to the error
|
||||||
|
Only gets added if the metadata contains the user_api_key_alias and user_api_key_team_alias
|
||||||
|
|
||||||
|
[Non-Blocking helper function]
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
_api_key_name = metadata.get("user_api_key_alias", None)
|
||||||
|
_user_api_key_team_alias = metadata.get("user_api_key_team_alias", None)
|
||||||
|
if _api_key_name is not None:
|
||||||
|
request_info = (
|
||||||
|
f"\n\nKey Name: `{_api_key_name}`\nTeam: `{_user_api_key_team_alias}`"
|
||||||
|
+ request_info
|
||||||
|
)
|
||||||
|
|
||||||
|
return request_info
|
||||||
|
except Exception:
|
||||||
|
return request_info
|
||||||
|
|
|
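The exception-utils hunk also adds `_add_key_name_and_team_to_alert`, a non-blocking proxy helper that prefixes an alert string with the key alias and team alias when the metadata carries them, and returns the input unchanged on any error. A usage sketch with made-up metadata values (the body mirrors the helper added above):

def add_key_name_and_team_to_alert(request_info: str, metadata: dict) -> str:
    # Same shape as litellm's internal helper; non-blocking by design.
    try:
        _api_key_name = metadata.get("user_api_key_alias", None)
        _user_api_key_team_alias = metadata.get("user_api_key_team_alias", None)
        if _api_key_name is not None:
            request_info = (
                f"\n\nKey Name: `{_api_key_name}`\nTeam: `{_user_api_key_team_alias}`"
                + request_info
            )
        return request_info
    except Exception:
        return request_info


print(
    add_key_name_and_team_to_alert(
        "\nRequest timed out after 600s",
        {"user_api_key_alias": "prod-key", "user_api_key_team_alias": "ml-platform"},
    )
)
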
@ -476,7 +476,7 @@ def get_llm_provider(
|
||||||
elif model == "*":
|
elif model == "*":
|
||||||
custom_llm_provider = "openai"
|
custom_llm_provider = "openai"
|
||||||
if custom_llm_provider is None or custom_llm_provider == "":
|
if custom_llm_provider is None or custom_llm_provider == "":
|
||||||
if litellm.suppress_debug_info == False:
|
if litellm.suppress_debug_info is False:
|
||||||
print() # noqa
|
print() # noqa
|
||||||
print( # noqa
|
print( # noqa
|
||||||
"\033[1;31mProvider List: https://docs.litellm.ai/docs/providers\033[0m" # noqa
|
"\033[1;31mProvider List: https://docs.litellm.ai/docs/providers\033[0m" # noqa
|
||||||
|
|
|
@ -52,18 +52,8 @@ from litellm.types.utils import (
|
||||||
)
|
)
|
||||||
from litellm.utils import (
|
from litellm.utils import (
|
||||||
_get_base_model_from_metadata,
|
_get_base_model_from_metadata,
|
||||||
add_breadcrumb,
|
|
||||||
capture_exception,
|
|
||||||
customLogger,
|
|
||||||
liteDebuggerClient,
|
|
||||||
logfireLogger,
|
|
||||||
lunaryLogger,
|
|
||||||
print_verbose,
|
print_verbose,
|
||||||
prometheusLogger,
|
|
||||||
prompt_token_calculator,
|
prompt_token_calculator,
|
||||||
promptLayerLogger,
|
|
||||||
supabaseClient,
|
|
||||||
weightsBiasesLogger,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
from ..integrations.aispend import AISpendLogger
|
from ..integrations.aispend import AISpendLogger
|
||||||
|
@ -71,7 +61,6 @@ from ..integrations.athina import AthinaLogger
|
||||||
from ..integrations.berrispend import BerriSpendLogger
|
from ..integrations.berrispend import BerriSpendLogger
|
||||||
from ..integrations.braintrust_logging import BraintrustLogger
|
from ..integrations.braintrust_logging import BraintrustLogger
|
||||||
from ..integrations.clickhouse import ClickhouseLogger
|
from ..integrations.clickhouse import ClickhouseLogger
|
||||||
from ..integrations.custom_logger import CustomLogger
|
|
||||||
from ..integrations.datadog.datadog import DataDogLogger
|
from ..integrations.datadog.datadog import DataDogLogger
|
||||||
from ..integrations.dynamodb import DyanmoDBLogger
|
from ..integrations.dynamodb import DyanmoDBLogger
|
||||||
from ..integrations.galileo import GalileoObserve
|
from ..integrations.galileo import GalileoObserve
|
||||||
|
@ -423,7 +412,7 @@ class Logging:
|
||||||
elif callback == "sentry" and add_breadcrumb:
|
elif callback == "sentry" and add_breadcrumb:
|
||||||
try:
|
try:
|
||||||
details_to_log = copy.deepcopy(self.model_call_details)
|
details_to_log = copy.deepcopy(self.model_call_details)
|
||||||
except:
|
except Exception:
|
||||||
details_to_log = self.model_call_details
|
details_to_log = self.model_call_details
|
||||||
if litellm.turn_off_message_logging:
|
if litellm.turn_off_message_logging:
|
||||||
# make a copy of the _model_Call_details and log it
|
# make a copy of the _model_Call_details and log it
|
||||||
|
@ -528,7 +517,7 @@ class Logging:
|
||||||
verbose_logger.debug("reaches sentry breadcrumbing")
|
verbose_logger.debug("reaches sentry breadcrumbing")
|
||||||
try:
|
try:
|
||||||
details_to_log = copy.deepcopy(self.model_call_details)
|
details_to_log = copy.deepcopy(self.model_call_details)
|
||||||
except:
|
except Exception:
|
||||||
details_to_log = self.model_call_details
|
details_to_log = self.model_call_details
|
||||||
if litellm.turn_off_message_logging:
|
if litellm.turn_off_message_logging:
|
||||||
# make a copy of the _model_Call_details and log it
|
# make a copy of the _model_Call_details and log it
|
||||||
|
@ -1326,7 +1315,7 @@ class Logging:
|
||||||
and customLogger is not None
|
and customLogger is not None
|
||||||
): # custom logger functions
|
): # custom logger functions
|
||||||
print_verbose(
|
print_verbose(
|
||||||
f"success callbacks: Running Custom Callback Function"
|
"success callbacks: Running Custom Callback Function"
|
||||||
)
|
)
|
||||||
customLogger.log_event(
|
customLogger.log_event(
|
||||||
kwargs=self.model_call_details,
|
kwargs=self.model_call_details,
|
||||||
|
@ -1400,7 +1389,7 @@ class Logging:
|
||||||
self.model_call_details["response_cost"] = 0.0
|
self.model_call_details["response_cost"] = 0.0
|
||||||
else:
|
else:
|
||||||
# check if base_model set on azure
|
# check if base_model set on azure
|
||||||
base_model = _get_base_model_from_metadata(
|
_get_base_model_from_metadata(
|
||||||
model_call_details=self.model_call_details
|
model_call_details=self.model_call_details
|
||||||
)
|
)
|
||||||
# base_model defaults to None if not set on model_info
|
# base_model defaults to None if not set on model_info
|
||||||
|
@ -1483,7 +1472,7 @@ class Logging:
|
||||||
for callback in callbacks:
|
for callback in callbacks:
|
||||||
# check if callback can run for this request
|
# check if callback can run for this request
|
||||||
litellm_params = self.model_call_details.get("litellm_params", {})
|
litellm_params = self.model_call_details.get("litellm_params", {})
|
||||||
if litellm_params.get("no-log", False) == True:
|
if litellm_params.get("no-log", False) is True:
|
||||||
# proxy cost tracking cal backs should run
|
# proxy cost tracking cal backs should run
|
||||||
if not (
|
if not (
|
||||||
isinstance(callback, CustomLogger)
|
isinstance(callback, CustomLogger)
|
||||||
|
@ -1492,7 +1481,7 @@ class Logging:
|
||||||
print_verbose("no-log request, skipping logging")
|
print_verbose("no-log request, skipping logging")
|
||||||
continue
|
continue
|
||||||
try:
|
try:
|
||||||
if kwargs.get("no-log", False) == True:
|
if kwargs.get("no-log", False) is True:
|
||||||
print_verbose("no-log request, skipping logging")
|
print_verbose("no-log request, skipping logging")
|
||||||
continue
|
continue
|
||||||
if (
|
if (
|
||||||
|
@ -1641,7 +1630,7 @@ class Logging:
|
||||||
end_time=end_time,
|
end_time=end_time,
|
||||||
print_verbose=print_verbose,
|
print_verbose=print_verbose,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception:
|
||||||
verbose_logger.error(
|
verbose_logger.error(
|
||||||
f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while success logging {traceback.format_exc()}"
|
f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while success logging {traceback.format_exc()}"
|
||||||
)
|
)
|
||||||
|
@ -2433,7 +2422,7 @@ def get_standard_logging_object_payload(
|
||||||
call_type = kwargs.get("call_type")
|
call_type = kwargs.get("call_type")
|
||||||
cache_hit = kwargs.get("cache_hit", False)
|
cache_hit = kwargs.get("cache_hit", False)
|
||||||
usage = response_obj.get("usage", None) or {}
|
usage = response_obj.get("usage", None) or {}
|
||||||
if type(usage) == litellm.Usage:
|
if type(usage) is litellm.Usage:
|
||||||
usage = dict(usage)
|
usage = dict(usage)
|
||||||
id = response_obj.get("id", kwargs.get("litellm_call_id"))
|
id = response_obj.get("id", kwargs.get("litellm_call_id"))
|
||||||
|
|
||||||
|
@ -2656,3 +2645,11 @@ def scrub_sensitive_keys_in_metadata(litellm_params: Optional[dict]):
|
||||||
litellm_params["metadata"] = metadata
|
litellm_params["metadata"] = metadata
|
||||||
|
|
||||||
return litellm_params
|
return litellm_params
|
||||||
|
|
||||||
|
|
||||||
|
# integration helper function
|
||||||
|
def modify_integration(integration_name, integration_params):
|
||||||
|
global supabaseClient
|
||||||
|
if integration_name == "supabase":
|
||||||
|
if "table_name" in integration_params:
|
||||||
|
Supabase.supabase_table_name = integration_params["table_name"]
|
||||||
|
|
|
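The tail of the logging-module diff adds the `modify_integration` helper: it routes per-integration overrides by name, and for now only lets callers point the Supabase integration at a different table. A self-contained sketch of that shape (the class and table name are invented):

class SupabaseIntegration:
    """Stand-in for litellm's Supabase integration class."""

    supabase_table_name = "request_logs"


def modify_integration(integration_name: str, integration_params: dict) -> None:
    # Same dispatch shape as the helper added above.
    if integration_name == "supabase":
        if "table_name" in integration_params:
            SupabaseIntegration.supabase_table_name = integration_params["table_name"]


modify_integration("supabase", {"table_name": "litellm_request_logs_v2"})
print(SupabaseIntegration.supabase_table_name)  # litellm_request_logs_v2
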
@ -45,7 +45,7 @@ def pick_cheapest_chat_model_from_llm_provider(custom_llm_provider: str):
|
||||||
model_info = litellm.get_model_info(
|
model_info = litellm.get_model_info(
|
||||||
model=model, custom_llm_provider=custom_llm_provider
|
model=model, custom_llm_provider=custom_llm_provider
|
||||||
)
|
)
|
||||||
except:
|
except Exception:
|
||||||
continue
|
continue
|
||||||
if model_info.get("mode") != "chat":
|
if model_info.get("mode") != "chat":
|
||||||
continue
|
continue
|
||||||
|
|
|
@ -123,7 +123,7 @@ def completion(
|
||||||
encoding,
|
encoding,
|
||||||
api_key,
|
api_key,
|
||||||
logging_obj,
|
logging_obj,
|
||||||
optional_params=None,
|
optional_params: dict,
|
||||||
litellm_params=None,
|
litellm_params=None,
|
||||||
logger_fn=None,
|
logger_fn=None,
|
||||||
):
|
):
|
||||||
|
@ -165,7 +165,7 @@ def completion(
|
||||||
)
|
)
|
||||||
if response.status_code != 200:
|
if response.status_code != 200:
|
||||||
raise AI21Error(status_code=response.status_code, message=response.text)
|
raise AI21Error(status_code=response.status_code, message=response.text)
|
||||||
if "stream" in optional_params and optional_params["stream"] == True:
|
if "stream" in optional_params and optional_params["stream"] is True:
|
||||||
return response.iter_lines()
|
return response.iter_lines()
|
||||||
else:
|
else:
|
||||||
## LOGGING
|
## LOGGING
|
||||||
|
@ -191,7 +191,7 @@ def completion(
|
||||||
)
|
)
|
||||||
choices_list.append(choice_obj)
|
choices_list.append(choice_obj)
|
||||||
model_response.choices = choices_list # type: ignore
|
model_response.choices = choices_list # type: ignore
|
||||||
except Exception as e:
|
except Exception:
|
||||||
raise AI21Error(
|
raise AI21Error(
|
||||||
message=traceback.format_exc(), status_code=response.status_code
|
message=traceback.format_exc(), status_code=response.status_code
|
||||||
)
|
)
|
||||||
|
|
|
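In the AI21 handler the signature tightens from `optional_params=None` to `optional_params: dict`, so callers must always pass a dict and the later `"stream" in optional_params` membership test can never hit `None`; the same hunk rewrites `optional_params["stream"] == True` as `is True`. A sketch of that signature change with an invented function:

from typing import Iterator, Union


def completion(prompt: str, optional_params: dict) -> Union[Iterator[str], str]:
    # `optional_params` is required and typed, so the membership check is safe.
    if "stream" in optional_params and optional_params["stream"] is True:
        return iter([prompt[:5], prompt[5:]])  # pretend streaming chunks
    return prompt.upper()


print(completion("hello world", {"stream": False}))
for chunk in completion("hello world", {"stream": True}):
    print(chunk)
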
@ -151,7 +151,7 @@ class AzureAssistantsAPI(BaseLLM):
|
||||||
client=None,
|
client=None,
|
||||||
aget_assistants=None,
|
aget_assistants=None,
|
||||||
):
|
):
|
||||||
if aget_assistants is not None and aget_assistants == True:
|
if aget_assistants is not None and aget_assistants is True:
|
||||||
return self.async_get_assistants(
|
return self.async_get_assistants(
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
api_base=api_base,
|
api_base=api_base,
|
||||||
|
@ -260,7 +260,7 @@ class AzureAssistantsAPI(BaseLLM):
|
||||||
client=None,
|
client=None,
|
||||||
a_add_message: Optional[bool] = None,
|
a_add_message: Optional[bool] = None,
|
||||||
):
|
):
|
||||||
if a_add_message is not None and a_add_message == True:
|
if a_add_message is not None and a_add_message is True:
|
||||||
return self.a_add_message(
|
return self.a_add_message(
|
||||||
thread_id=thread_id,
|
thread_id=thread_id,
|
||||||
message_data=message_data,
|
message_data=message_data,
|
||||||
|
@ -365,7 +365,7 @@ class AzureAssistantsAPI(BaseLLM):
|
||||||
client=None,
|
client=None,
|
||||||
aget_messages=None,
|
aget_messages=None,
|
||||||
):
|
):
|
||||||
if aget_messages is not None and aget_messages == True:
|
if aget_messages is not None and aget_messages is True:
|
||||||
return self.async_get_messages(
|
return self.async_get_messages(
|
||||||
thread_id=thread_id,
|
thread_id=thread_id,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
|
@ -483,7 +483,7 @@ class AzureAssistantsAPI(BaseLLM):
|
||||||
openai_api.create_thread(messages=[message])
|
openai_api.create_thread(messages=[message])
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
if acreate_thread is not None and acreate_thread == True:
|
if acreate_thread is not None and acreate_thread is True:
|
||||||
return self.async_create_thread(
|
return self.async_create_thread(
|
||||||
metadata=metadata,
|
metadata=metadata,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
|
@ -586,7 +586,7 @@ class AzureAssistantsAPI(BaseLLM):
|
||||||
client=None,
|
client=None,
|
||||||
aget_thread=None,
|
aget_thread=None,
|
||||||
):
|
):
|
||||||
if aget_thread is not None and aget_thread == True:
|
if aget_thread is not None and aget_thread is True:
|
||||||
return self.async_get_thread(
|
return self.async_get_thread(
|
||||||
thread_id=thread_id,
|
thread_id=thread_id,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
|
@ -774,8 +774,8 @@ class AzureAssistantsAPI(BaseLLM):
|
||||||
arun_thread=None,
|
arun_thread=None,
|
||||||
event_handler: Optional[AssistantEventHandler] = None,
|
event_handler: Optional[AssistantEventHandler] = None,
|
||||||
):
|
):
|
||||||
if arun_thread is not None and arun_thread == True:
|
if arun_thread is not None and arun_thread is True:
|
||||||
if stream is not None and stream == True:
|
if stream is not None and stream is True:
|
||||||
azure_client = self.async_get_azure_client(
|
azure_client = self.async_get_azure_client(
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
api_base=api_base,
|
api_base=api_base,
|
||||||
|
@ -823,7 +823,7 @@ class AzureAssistantsAPI(BaseLLM):
|
||||||
client=client,
|
client=client,
|
||||||
)
|
)
|
||||||
|
|
||||||
if stream is not None and stream == True:
|
if stream is not None and stream is True:
|
||||||
return self.run_thread_stream(
|
return self.run_thread_stream(
|
||||||
client=openai_client,
|
client=openai_client,
|
||||||
thread_id=thread_id,
|
thread_id=thread_id,
|
||||||
|
@ -887,7 +887,7 @@ class AzureAssistantsAPI(BaseLLM):
|
||||||
client=None,
|
client=None,
|
||||||
async_create_assistants=None,
|
async_create_assistants=None,
|
||||||
):
|
):
|
||||||
if async_create_assistants is not None and async_create_assistants == True:
|
if async_create_assistants is not None and async_create_assistants is True:
|
||||||
return self.async_create_assistants(
|
return self.async_create_assistants(
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
api_base=api_base,
|
api_base=api_base,
|
||||||
|
@ -950,7 +950,7 @@ class AzureAssistantsAPI(BaseLLM):
|
||||||
async_delete_assistants: Optional[bool] = None,
|
async_delete_assistants: Optional[bool] = None,
|
||||||
client=None,
|
client=None,
|
||||||
):
|
):
|
||||||
if async_delete_assistants is not None and async_delete_assistants == True:
|
if async_delete_assistants is not None and async_delete_assistants is True:
|
||||||
return self.async_delete_assistant(
|
return self.async_delete_assistant(
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
api_base=api_base,
|
api_base=api_base,
|
||||||
|
|
|
@ -317,7 +317,7 @@ class AzureOpenAIAssistantsAPIConfig:
|
||||||
if "file_id" in item:
|
if "file_id" in item:
|
||||||
file_ids.append(item["file_id"])
|
file_ids.append(item["file_id"])
|
||||||
else:
|
else:
|
||||||
if litellm.drop_params == True:
|
if litellm.drop_params is True:
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
raise litellm.utils.UnsupportedParamsError(
|
raise litellm.utils.UnsupportedParamsError(
|
||||||
|
@ -580,7 +580,7 @@ class AzureChatCompletion(BaseLLM):
|
||||||
try:
|
try:
|
||||||
if model is None or messages is None:
|
if model is None or messages is None:
|
||||||
raise AzureOpenAIError(
|
raise AzureOpenAIError(
|
||||||
status_code=422, message=f"Missing model or messages"
|
status_code=422, message="Missing model or messages"
|
||||||
)
|
)
|
||||||
|
|
||||||
max_retries = optional_params.pop("max_retries", 2)
|
max_retries = optional_params.pop("max_retries", 2)
|
||||||
|
@ -1240,12 +1240,6 @@ class AzureChatCompletion(BaseLLM):
|
||||||
)
|
)
|
||||||
while response.json()["status"] not in ["succeeded", "failed"]:
|
while response.json()["status"] not in ["succeeded", "failed"]:
|
||||||
if time.time() - start_time > timeout_secs:
|
if time.time() - start_time > timeout_secs:
|
||||||
timeout_msg = {
|
|
||||||
"error": {
|
|
||||||
"code": "Timeout",
|
|
||||||
"message": "Operation polling timed out.",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
raise AzureOpenAIError(
|
raise AzureOpenAIError(
|
||||||
status_code=408, message="Operation polling timed out."
|
status_code=408, message="Operation polling timed out."
|
||||||
|
@ -1493,7 +1487,6 @@ class AzureChatCompletion(BaseLLM):
|
||||||
client=None,
|
client=None,
|
||||||
aimg_generation=None,
|
aimg_generation=None,
|
||||||
):
|
):
|
||||||
exception_mapping_worked = False
|
|
||||||
try:
|
try:
|
||||||
if model and len(model) > 0:
|
if model and len(model) > 0:
|
||||||
model = model
|
model = model
|
||||||
|
@ -1534,7 +1527,7 @@ class AzureChatCompletion(BaseLLM):
|
||||||
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
|
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
|
||||||
azure_client_params["azure_ad_token"] = azure_ad_token
|
azure_client_params["azure_ad_token"] = azure_ad_token
|
||||||
|
|
||||||
if aimg_generation == True:
|
if aimg_generation is True:
|
||||||
response = self.aimage_generation(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_key=api_key, client=client, azure_client_params=azure_client_params, timeout=timeout) # type: ignore
|
response = self.aimage_generation(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_key=api_key, client=client, azure_client_params=azure_client_params, timeout=timeout) # type: ignore
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
|
@@ -1263,7 +1263,6 @@ class OpenAIChatCompletion(BaseLLM):
         client=None,
         aimg_generation=None,
     ):
-        exception_mapping_worked = False
         data = {}
         try:
             model = model

@@ -1272,7 +1271,7 @@ class OpenAIChatCompletion(BaseLLM):
             if not isinstance(max_retries, int):
                 raise OpenAIError(status_code=422, message="max retries must be an int")

-            if aimg_generation == True:
+            if aimg_generation is True:
                 response = self.aimage_generation(data=data, prompt=prompt, logging_obj=logging_obj, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries)  # type: ignore
                 return response

@@ -1311,7 +1310,6 @@ class OpenAIChatCompletion(BaseLLM):
             return convert_to_model_response_object(response_object=response, model_response_object=model_response, response_type="image_generation")  # type: ignore
         except OpenAIError as e:

-            exception_mapping_worked = True
             ## LOGGING
             logging_obj.post_call(
                 input=prompt,

@@ -1543,7 +1541,7 @@ class OpenAITextCompletion(BaseLLM):
             if (
                 len(messages) > 0
                 and "content" in messages[0]
-                and type(messages[0]["content"]) == list
+                and isinstance(messages[0]["content"], list)
             ):
                 prompt = messages[0]["content"]
             else:

@@ -2413,7 +2411,7 @@ class OpenAIAssistantsAPI(BaseLLM):
         client=None,
         aget_assistants=None,
     ):
-        if aget_assistants is not None and aget_assistants == True:
+        if aget_assistants is not None and aget_assistants is True:
             return self.async_get_assistants(
                 api_key=api_key,
                 api_base=api_base,

@@ -2470,7 +2468,7 @@ class OpenAIAssistantsAPI(BaseLLM):
         client=None,
         async_create_assistants=None,
     ):
-        if async_create_assistants is not None and async_create_assistants == True:
+        if async_create_assistants is not None and async_create_assistants is True:
             return self.async_create_assistants(
                 api_key=api_key,
                 api_base=api_base,

@@ -2527,7 +2525,7 @@ class OpenAIAssistantsAPI(BaseLLM):
         client=None,
         async_delete_assistants=None,
     ):
-        if async_delete_assistants is not None and async_delete_assistants == True:
+        if async_delete_assistants is not None and async_delete_assistants is True:
             return self.async_delete_assistant(
                 api_key=api_key,
                 api_base=api_base,

@@ -2629,7 +2627,7 @@ class OpenAIAssistantsAPI(BaseLLM):
         client=None,
         a_add_message: Optional[bool] = None,
     ):
-        if a_add_message is not None and a_add_message == True:
+        if a_add_message is not None and a_add_message is True:
             return self.a_add_message(
                 thread_id=thread_id,
                 message_data=message_data,

@@ -2727,7 +2725,7 @@ class OpenAIAssistantsAPI(BaseLLM):
         client=None,
         aget_messages=None,
     ):
-        if aget_messages is not None and aget_messages == True:
+        if aget_messages is not None and aget_messages is True:
             return self.async_get_messages(
                 thread_id=thread_id,
                 api_key=api_key,

@@ -2838,7 +2836,7 @@ class OpenAIAssistantsAPI(BaseLLM):
        openai_api.create_thread(messages=[message])
        ```
        """
-        if acreate_thread is not None and acreate_thread == True:
+        if acreate_thread is not None and acreate_thread is True:
             return self.async_create_thread(
                 metadata=metadata,
                 api_key=api_key,

@@ -2934,7 +2932,7 @@ class OpenAIAssistantsAPI(BaseLLM):
         client=None,
         aget_thread=None,
     ):
-        if aget_thread is not None and aget_thread == True:
+        if aget_thread is not None and aget_thread is True:
             return self.async_get_thread(
                 thread_id=thread_id,
                 api_key=api_key,

@@ -3117,8 +3115,8 @@ class OpenAIAssistantsAPI(BaseLLM):
         arun_thread=None,
         event_handler: Optional[AssistantEventHandler] = None,
     ):
-        if arun_thread is not None and arun_thread == True:
-            if stream is not None and stream == True:
+        if arun_thread is not None and arun_thread is True:
+            if stream is not None and stream is True:
                 _client = self.async_get_openai_client(
                     api_key=api_key,
                     api_base=api_base,

@@ -3163,7 +3161,7 @@ class OpenAIAssistantsAPI(BaseLLM):
             client=client,
         )

-        if stream is not None and stream == True:
+        if stream is not None and stream is True:
             return self.run_thread_stream(
                 client=openai_client,
                 thread_id=thread_id,

@@ -191,7 +191,7 @@ def completion(
     encoding,
     api_key,
     logging_obj,
-    optional_params=None,
+    optional_params: dict,
     litellm_params=None,
     logger_fn=None,
     default_max_tokens_to_sample=None,

@@ -246,7 +246,7 @@ def completion(
         data=json.dumps(data),
         stream=optional_params["stream"] if "stream" in optional_params else False,
     )
-    if "stream" in optional_params and optional_params["stream"] == True:
+    if "stream" in optional_params and optional_params["stream"] is True:
         return response.iter_lines()
     else:
         ## LOGGING

@@ -279,7 +279,7 @@ def completion(
                 )
                 choices_list.append(choice_obj)
             model_response.choices = choices_list  # type: ignore
-        except:
+        except Exception:
            raise AlephAlphaError(
                message=json.dumps(completion_response),
                status_code=response.status_code,

@@ -607,7 +607,6 @@ class ModelResponseIterator:
     def _handle_usage(
         self, anthropic_usage_chunk: Union[dict, UsageDelta]
     ) -> AnthropicChatCompletionUsageBlock:
-        special_fields = ["input_tokens", "output_tokens"]

         usage_block = AnthropicChatCompletionUsageBlock(
             prompt_tokens=anthropic_usage_chunk.get("input_tokens", 0),

@@ -683,7 +682,7 @@ class ModelResponseIterator:
                     "index": self.tool_index,
                 }
             elif type_chunk == "content_block_stop":
-                content_block_stop = ContentBlockStop(**chunk)  # type: ignore
+                ContentBlockStop(**chunk)  # type: ignore
                 # check if tool call content block
                 is_empty = self.check_empty_tool_call_args()
                 if is_empty:

@@ -114,7 +114,7 @@ class AnthropicTextCompletion(BaseLLM):
         ## RESPONSE OBJECT
         try:
             completion_response = response.json()
-        except:
+        except Exception:
             raise AnthropicError(
                 message=response.text, status_code=response.status_code
             )

@@ -229,7 +229,7 @@ class AnthropicTextCompletion(BaseLLM):
         encoding,
         api_key,
         logging_obj,
-        optional_params=None,
+        optional_params: dict,
         litellm_params=None,
         logger_fn=None,
         headers={},

@@ -276,8 +276,8 @@ class AnthropicTextCompletion(BaseLLM):
         )

         ## COMPLETION CALL
-        if "stream" in optional_params and optional_params["stream"] == True:
-            if acompletion == True:
+        if "stream" in optional_params and optional_params["stream"] is True:
+            if acompletion is True:
                 return self.async_streaming(
                     model=model,
                     api_base=api_base,

@@ -309,7 +309,7 @@ class AnthropicTextCompletion(BaseLLM):
                 logging_obj=logging_obj,
             )
             return stream_response
-        elif acompletion == True:
+        elif acompletion is True:
             return self.async_completion(
                 model=model,
                 model_response=model_response,

@@ -233,7 +233,7 @@ class AzureTextCompletion(BaseLLM):
                     client=client,
                     logging_obj=logging_obj,
                 )
-            elif "stream" in optional_params and optional_params["stream"] == True:
+            elif "stream" in optional_params and optional_params["stream"] is True:
                 return self.streaming(
                     logging_obj=logging_obj,
                     api_base=api_base,

@@ -36,7 +36,7 @@ def completion(
     encoding,
     api_key,
     logging_obj,
-    optional_params=None,
+    optional_params: dict,
     litellm_params=None,
     logger_fn=None,
 ):

@@ -59,7 +59,7 @@ def completion(
         "parameters": optional_params,
         "stream": (
             True
-            if "stream" in optional_params and optional_params["stream"] == True
+            if "stream" in optional_params and optional_params["stream"] is True
             else False
         ),
     }

@@ -77,12 +77,12 @@ def completion(
         data=json.dumps(data),
         stream=(
             True
-            if "stream" in optional_params and optional_params["stream"] == True
+            if "stream" in optional_params and optional_params["stream"] is True
             else False
         ),
     )
     if "text/event-stream" in response.headers["Content-Type"] or (
-        "stream" in optional_params and optional_params["stream"] == True
+        "stream" in optional_params and optional_params["stream"] is True
     ):
         return response.iter_lines()
     else:

@@ -183,7 +183,7 @@ class BedrockConverseLLM(BaseAWSLLM):
         except httpx.HTTPStatusError as err:
             error_code = err.response.status_code
             raise BedrockError(status_code=error_code, message=err.response.text)
-        except httpx.TimeoutException as e:
+        except httpx.TimeoutException:
             raise BedrockError(status_code=408, message="Timeout error occurred.")

         return litellm.AmazonConverseConfig()._transform_response(

@@ -251,9 +251,7 @@ class AmazonConverseConfig:
         supported_converse_params = AmazonConverseConfig.__annotations__.keys()
         supported_tool_call_params = ["tools", "tool_choice"]
         supported_guardrail_params = ["guardrailConfig"]
-        json_mode: Optional[bool] = inference_params.pop(
-            "json_mode", None
-        )  # used for handling json_schema
+        inference_params.pop("json_mode", None)  # used for handling json_schema
         ## TRANSFORMATION ##

         bedrock_messages: List[MessageBlock] = _bedrock_converse_messages_pt(

@@ -234,7 +234,7 @@ async def make_call(
     except httpx.HTTPStatusError as err:
         error_code = err.response.status_code
         raise BedrockError(status_code=error_code, message=err.response.text)
-    except httpx.TimeoutException as e:
+    except httpx.TimeoutException:
         raise BedrockError(status_code=408, message="Timeout error occurred.")
     except Exception as e:
         raise BedrockError(status_code=500, message=str(e))

@@ -335,7 +335,7 @@ class BedrockLLM(BaseAWSLLM):
         ## RESPONSE OBJECT
         try:
             completion_response = response.json()
-        except:
+        except Exception:
             raise BedrockError(message=response.text, status_code=422)

         outputText: Optional[str] = None

@@ -394,12 +394,12 @@ class BedrockLLM(BaseAWSLLM):
                         outputText  # allow user to access raw anthropic tool calling response
                     )
                 if (
-                    _is_function_call == True
+                    _is_function_call is True
                     and stream is not None
-                    and stream == True
+                    and stream is True
                 ):
                     print_verbose(
-                        f"INSIDE BEDROCK STREAMING TOOL CALLING CONDITION BLOCK"
+                        "INSIDE BEDROCK STREAMING TOOL CALLING CONDITION BLOCK"
                     )
                     # return an iterator
                     streaming_model_response = ModelResponse(stream=True)

@@ -440,7 +440,7 @@ class BedrockLLM(BaseAWSLLM):
                         model_response=streaming_model_response
                     )
                     print_verbose(
-                        f"Returns anthropic CustomStreamWrapper with 'cached_response' streaming object"
+                        "Returns anthropic CustomStreamWrapper with 'cached_response' streaming object"
                     )
                     return litellm.CustomStreamWrapper(
                         completion_stream=completion_stream,

@@ -597,7 +597,7 @@ class BedrockLLM(BaseAWSLLM):
             from botocore.auth import SigV4Auth
             from botocore.awsrequest import AWSRequest
             from botocore.credentials import Credentials
-        except ImportError as e:
+        except ImportError:
             raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")

         ## SETUP ##

@@ -700,7 +700,7 @@ class BedrockLLM(BaseAWSLLM):
                     k not in inference_params
                 ):  # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
                     inference_params[k] = v
-            if stream == True:
+            if stream is True:
                 inference_params["stream"] = (
                     True  # cohere requires stream = True in inference params
                 )

@@ -845,7 +845,7 @@ class BedrockLLM(BaseAWSLLM):
         if acompletion:
             if isinstance(client, HTTPHandler):
                 client = None
-            if stream == True and provider != "ai21":
+            if stream is True and provider != "ai21":
                 return self.async_streaming(
                     model=model,
                     messages=messages,

@@ -891,7 +891,7 @@ class BedrockLLM(BaseAWSLLM):
             self.client = _get_httpx_client(_params)  # type: ignore
         else:
             self.client = client
-        if (stream is not None and stream == True) and provider != "ai21":
+        if (stream is not None and stream is True) and provider != "ai21":
             response = self.client.post(
                 url=proxy_endpoint_url,
                 headers=prepped.headers,  # type: ignore

@@ -929,7 +929,7 @@ class BedrockLLM(BaseAWSLLM):
             except httpx.HTTPStatusError as err:
                 error_code = err.response.status_code
                 raise BedrockError(status_code=error_code, message=err.response.text)
-            except httpx.TimeoutException as e:
+            except httpx.TimeoutException:
                 raise BedrockError(status_code=408, message="Timeout error occurred.")

             return self.process_response(

@@ -980,7 +980,7 @@ class BedrockLLM(BaseAWSLLM):
         except httpx.HTTPStatusError as err:
             error_code = err.response.status_code
             raise BedrockError(status_code=error_code, message=err.response.text)
-        except httpx.TimeoutException as e:
+        except httpx.TimeoutException:
             raise BedrockError(status_code=408, message="Timeout error occurred.")

         return self.process_response(

@@ -260,7 +260,7 @@ class AmazonAnthropicConfig:
                 optional_params["top_p"] = value
             if param == "stop":
                 optional_params["stop_sequences"] = value
-            if param == "stream" and value == True:
+            if param == "stream" and value is True:
                 optional_params["stream"] = value
         return optional_params

@@ -39,7 +39,7 @@ class BedrockEmbedding(BaseAWSLLM):
     ) -> Tuple[Any, str]:
         try:
             from botocore.credentials import Credentials
-        except ImportError as e:
+        except ImportError:
             raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
         ## CREDENTIALS ##
         # pop aws_secret_access_key, aws_access_key_id, aws_session_token, aws_region_name from kwargs, since completion calls fail with them

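Note: the Bedrock hunks above repeat two more ruff fixes: bare `except:` handlers become `except Exception:` (the E722 check), and exception bindings or f-string prefixes that are never used are removed (F841 / F541). A minimal, self-contained sketch of the exception-handling pattern; the `load` helper and its input are purely illustrative:

```python
import json


def load(response_text: str) -> dict:
    # Before (E722): a bare `except:` also swallows KeyboardInterrupt/SystemExit
    # try:
    #     return json.loads(response_text)
    # except:
    #     return {}

    # After: catch Exception explicitly; an unused `as e` binding is dropped
    try:
        return json.loads(response_text)
    except Exception:
        return {}
```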
@@ -130,7 +130,7 @@ def process_response(
             choices_list.append(choice_obj)
         model_response.choices = choices_list  # type: ignore

-    except Exception as e:
+    except Exception:
         raise ClarifaiError(
             message=traceback.format_exc(), status_code=response.status_code, url=model
         )

@@ -219,7 +219,7 @@ async def async_completion(
             choices_list.append(choice_obj)
         model_response.choices = choices_list  # type: ignore

-    except Exception as e:
+    except Exception:
         raise ClarifaiError(
             message=traceback.format_exc(), status_code=response.status_code, url=model
         )

@@ -251,9 +251,9 @@ def completion(
     encoding,
     api_key,
     logging_obj,
+    optional_params: dict,
     custom_prompt_dict={},
     acompletion=False,
-    optional_params=None,
     litellm_params=None,
     logger_fn=None,
 ):

@@ -268,20 +268,12 @@ def completion(
         optional_params[k] = v

     custom_llm_provider, orig_model_name = get_prompt_model_name(model)
-    if custom_llm_provider == "anthropic":
-        prompt = prompt_factory(
-            model=orig_model_name,
-            messages=messages,
-            api_key=api_key,
-            custom_llm_provider="clarifai",
-        )
-    else:
-        prompt = prompt_factory(
-            model=orig_model_name,
-            messages=messages,
-            api_key=api_key,
-            custom_llm_provider=custom_llm_provider,
-        )
+    prompt: str = prompt_factory(  # type: ignore
+        model=orig_model_name,
+        messages=messages,
+        api_key=api_key,
+        custom_llm_provider="clarifai",
+    )
     # print(prompt); exit(0)

     data = {

@@ -300,7 +292,7 @@ def completion(
             "api_base": model,
         },
     )
-    if acompletion == True:
+    if acompletion is True:
         return async_completion(
             model=model,
             prompt=prompt,

@@ -331,7 +323,7 @@ def completion(
             status_code=response.status_code, message=response.text, url=model
         )

-    if "stream" in optional_params and optional_params["stream"] == True:
+    if "stream" in optional_params and optional_params["stream"] is True:
         completion_stream = response.iter_lines()
         stream_response = CustomStreamWrapper(
             completion_stream=completion_stream,

@@ -80,8 +80,8 @@ def completion(
     encoding,
     api_key,
     logging_obj,
+    optional_params: dict,
     custom_prompt_dict={},
-    optional_params=None,
     litellm_params=None,
     logger_fn=None,
 ):

@@ -97,7 +97,7 @@ def completion(
     if model in custom_prompt_dict:
         # check if the model has a registered custom prompt
         model_prompt_details = custom_prompt_dict[model]
-        prompt = custom_prompt(
+        custom_prompt(
             role_dict=model_prompt_details.get("roles", {}),
             initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""),
             final_prompt_value=model_prompt_details.get("final_prompt_value", ""),

@@ -126,7 +126,7 @@ def completion(
     )

     ## COMPLETION CALL
-    if "stream" in optional_params and optional_params["stream"] == True:
+    if "stream" in optional_params and optional_params["stream"] is True:
         response = requests.post(
             api_base,
             headers=headers,

@@ -268,7 +268,7 @@ def completion(
     if response.status_code != 200:
         raise CohereError(message=response.text, status_code=response.status_code)

-    if "stream" in optional_params and optional_params["stream"] == True:
+    if "stream" in optional_params and optional_params["stream"] is True:
         return response.iter_lines()
     else:
         ## LOGGING

@@ -283,12 +283,12 @@ def completion(
         completion_response = response.json()
         try:
             model_response.choices[0].message.content = completion_response["text"]  # type: ignore
-        except Exception as e:
+        except Exception:
             raise CohereError(message=response.text, status_code=response.status_code)

         ## Tool calling response
         cohere_tools_response = completion_response.get("tool_calls", None)
-        if cohere_tools_response is not None and cohere_tools_response is not []:
+        if cohere_tools_response is not None and cohere_tools_response != []:
             # convert cohere_tools_response to OpenAI response format
             tool_calls = []
             for tool in cohere_tools_response:

@@ -146,7 +146,7 @@ def completion(
     api_key,
     logging_obj,
     headers: dict,
-    optional_params=None,
+    optional_params: dict,
     litellm_params=None,
     logger_fn=None,
 ):

@@ -198,7 +198,7 @@ def completion(
     if response.status_code != 200:
         raise CohereError(message=response.text, status_code=response.status_code)

-    if "stream" in optional_params and optional_params["stream"] == True:
+    if "stream" in optional_params and optional_params["stream"] is True:
         return response.iter_lines()
     else:
         ## LOGGING

@@ -231,7 +231,7 @@ def completion(
                 )
                 choices_list.append(choice_obj)
             model_response.choices = choices_list  # type: ignore
-        except Exception as e:
+        except Exception:
             raise CohereError(
                 message=response.text, status_code=response.status_code
             )

@@ -17,7 +17,7 @@ else:

 try:
     from litellm._version import version
-except:
+except Exception:
     version = "0.0.0"

 headers = {

@@ -1,15 +1,17 @@
 from typing import Optional

 import httpx

 try:
     from litellm._version import version
-except:
+except Exception:
     version = "0.0.0"

 headers = {
     "User-Agent": f"litellm/{version}",
 }


 class HTTPHandler:
     def __init__(self, concurrent_limit=1000):
         # Create a client with a connection pool

@@ -113,7 +113,7 @@ class DatabricksConfig:
                 optional_params["max_tokens"] = value
             if param == "n":
                 optional_params["n"] = value
-            if param == "stream" and value == True:
+            if param == "stream" and value is True:
                 optional_params["stream"] = value
             if param == "temperature":
                 optional_params["temperature"] = value

@@ -564,7 +564,7 @@ class DatabricksChatCompletion(BaseLLM):
                     status_code=e.response.status_code,
                     message=e.response.text,
                 )
-            except httpx.TimeoutException as e:
+            except httpx.TimeoutException:
                 raise DatabricksError(
                     status_code=408, message="Timeout error occurred."
                 )

@@ -614,7 +614,7 @@ class DatabricksChatCompletion(BaseLLM):
                     status_code=e.response.status_code,
                     message=response.text if response else str(e),
                 )
-            except httpx.TimeoutException as e:
+            except httpx.TimeoutException:
                 raise DatabricksError(
                     status_code=408, message="Timeout error occurred."
                 )

@@ -669,7 +669,7 @@ class DatabricksChatCompletion(BaseLLM):
             additional_args={"complete_input_dict": data, "api_base": api_base},
         )

-        if aembedding == True:
+        if aembedding is True:
             return self.aembedding(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, headers=headers)  # type: ignore
         if client is None or isinstance(client, AsyncHTTPHandler):
             self.client = HTTPHandler(timeout=timeout)  # type: ignore

@@ -692,7 +692,7 @@ class DatabricksChatCompletion(BaseLLM):
                 status_code=e.response.status_code,
                 message=e.response.text,
             )
-        except httpx.TimeoutException as e:
+        except httpx.TimeoutException:
             raise DatabricksError(status_code=408, message="Timeout error occurred.")
         except Exception as e:
             raise DatabricksError(status_code=500, message=str(e))

@@ -71,7 +71,7 @@ class AzureOpenAIFilesAPI(BaseLLM):
         self,
         _is_async: bool,
         create_file_data: CreateFileRequest,
-        api_base: str,
+        api_base: Optional[str],
         api_key: Optional[str],
         api_version: Optional[str],
         timeout: Union[float, httpx.Timeout],

@@ -117,7 +117,7 @@ class AzureOpenAIFilesAPI(BaseLLM):
         self,
         _is_async: bool,
         file_content_request: FileContentRequest,
-        api_base: str,
+        api_base: Optional[str],
         api_key: Optional[str],
         timeout: Union[float, httpx.Timeout],
         max_retries: Optional[int],

@@ -168,7 +168,7 @@ class AzureOpenAIFilesAPI(BaseLLM):
         self,
         _is_async: bool,
         file_id: str,
-        api_base: str,
+        api_base: Optional[str],
         api_key: Optional[str],
         timeout: Union[float, httpx.Timeout],
         max_retries: Optional[int],

@@ -220,7 +220,7 @@ class AzureOpenAIFilesAPI(BaseLLM):
         self,
         _is_async: bool,
         file_id: str,
-        api_base: str,
+        api_base: Optional[str],
         api_key: Optional[str],
         timeout: Union[float, httpx.Timeout],
         max_retries: Optional[int],

@@ -275,7 +275,7 @@ class AzureOpenAIFilesAPI(BaseLLM):
     def list_files(
         self,
         _is_async: bool,
-        api_base: str,
+        api_base: Optional[str],
         api_key: Optional[str],
         timeout: Union[float, httpx.Timeout],
         max_retries: Optional[int],

@@ -41,7 +41,7 @@ class VertexFineTuningAPI(VertexLLM):
             created_at = int(create_time_datetime.timestamp())

             return created_at
-        except Exception as e:
+        except Exception:
             return 0

     def convert_vertex_response_to_open_ai_response(

@@ -136,7 +136,7 @@ class GeminiConfig:
 # ):
 #     try:
 #         import google.generativeai as genai  # type: ignore
-#     except:
+#     except Exception:
 #         raise Exception(
 #             "Importing google.generativeai failed, please run 'pip install -q google-generativeai"
 #         )

@@ -282,7 +282,7 @@ class GeminiConfig:
 #             completion_response = model_response["choices"][0]["message"].get("content")
 #             if completion_response is None:
 #                 raise Exception
-#         except:
+#         except Exception:
 #             original_response = f"response: {response}"
 #             if hasattr(response, "candidates"):
 #                 original_response = f"response: {response.candidates}"

@@ -374,7 +374,7 @@ class GeminiConfig:
 #             completion_response = model_response["choices"][0]["message"].get("content")
 #             if completion_response is None:
 #                 raise Exception
-#         except:
+#         except Exception:
 #             original_response = f"response: {response}"
 #             if hasattr(response, "candidates"):
 #                 original_response = f"response: {response.candidates}"

@@ -13,6 +13,7 @@ import requests

 import litellm
 from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
+from litellm.secret_managers.main import get_secret_str
 from litellm.types.completion import ChatCompletionMessageToolCallParam
 from litellm.utils import Choices, CustomStreamWrapper, Message, ModelResponse, Usage

@@ -181,7 +182,7 @@ class HuggingfaceConfig:
         return optional_params

     def get_hf_api_key(self) -> Optional[str]:
-        return litellm.utils.get_secret("HUGGINGFACE_API_KEY")
+        return get_secret_str("HUGGINGFACE_API_KEY")


 def output_parser(generated_text: str):

@@ -240,7 +241,7 @@ def read_tgi_conv_models():
         # Cache the set for future use
         conv_models_cache = conv_models
         return tgi_models, conv_models
-    except:
+    except Exception:
         return set(), set()


@@ -372,7 +373,7 @@ class Huggingface(BaseLLM):
                     ]["finish_reason"]
                 sum_logprob = 0
                 for token in completion_response[0]["details"]["tokens"]:
-                    if token["logprob"] != None:
+                    if token["logprob"] is not None:
                         sum_logprob += token["logprob"]
                 setattr(model_response.choices[0].message, "_logprob", sum_logprob)  # type: ignore
             if "best_of" in optional_params and optional_params["best_of"] > 1:

@@ -386,7 +387,7 @@ class Huggingface(BaseLLM):
                     ):
                         sum_logprob = 0
                         for token in item["tokens"]:
-                            if token["logprob"] != None:
+                            if token["logprob"] is not None:
                                 sum_logprob += token["logprob"]
                         if len(item["generated_text"]) > 0:
                             message_obj = Message(

@@ -417,7 +418,7 @@ class Huggingface(BaseLLM):
             prompt_tokens = len(
                 encoding.encode(input_text)
             )  ##[TODO] use the llama2 tokenizer here
-        except:
+        except Exception:
             # this should remain non blocking we should not block a response returning if calculating usage fails
             pass
         output_text = model_response["choices"][0]["message"].get("content", "")

@@ -429,7 +430,7 @@ class Huggingface(BaseLLM):
                     model_response["choices"][0]["message"].get("content", "")
                 )
             )  ##[TODO] use the llama2 tokenizer here
-        except:
+        except Exception:
             # this should remain non blocking we should not block a response returning if calculating usage fails
             pass
         else:

@@ -559,7 +560,7 @@ class Huggingface(BaseLLM):
                     True
                     if "stream" in optional_params
                     and isinstance(optional_params["stream"], bool)
-                    and optional_params["stream"] == True  # type: ignore
+                    and optional_params["stream"] is True  # type: ignore
                     else False
                 ),
             }

@@ -595,7 +596,7 @@ class Huggingface(BaseLLM):
                 data["stream"] = (  # type: ignore
                     True  # type: ignore
                     if "stream" in optional_params
-                    and optional_params["stream"] == True
+                    and optional_params["stream"] is True
                     else False
                 )
                 input_text = prompt

@@ -631,7 +632,7 @@ class Huggingface(BaseLLM):
                 ### ASYNC COMPLETION
                 return self.acompletion(api_base=completion_url, data=data, headers=headers, model_response=model_response, task=task, encoding=encoding, input_text=input_text, model=model, optional_params=optional_params, timeout=timeout)  # type: ignore
             ### SYNC STREAMING
-            if "stream" in optional_params and optional_params["stream"] == True:
+            if "stream" in optional_params and optional_params["stream"] is True:
                 response = requests.post(
                     completion_url,
                     headers=headers,

@@ -691,7 +692,7 @@ class Huggingface(BaseLLM):
                 completion_response = response.json()
                 if isinstance(completion_response, dict):
                     completion_response = [completion_response]
-            except:
+            except Exception:
                 import traceback

                 raise HuggingfaceError(

@@ -101,7 +101,7 @@ def completion(
     encoding,
     api_key,
     logging_obj,
-    optional_params=None,
+    optional_params: dict,
     litellm_params=None,
     logger_fn=None,
 ):

@@ -135,7 +135,7 @@ def completion(
         data=json.dumps(data),
         stream=optional_params["stream"] if "stream" in optional_params else False,
     )
-    if "stream" in optional_params and optional_params["stream"] == True:
+    if "stream" in optional_params and optional_params["stream"] is True:
         return response.iter_lines()
     else:
         ## LOGGING

@@ -159,7 +159,7 @@ def completion(
             model_response.choices[0].message.content = completion_response[  # type: ignore
                 "answer"
             ]
-        except Exception as e:
+        except Exception:
             raise MaritalkError(
                 message=response.text, status_code=response.status_code
             )

@@ -120,7 +120,7 @@ def completion(
     encoding,
     api_key,
     logging_obj,
-    optional_params=None,
+    optional_params: dict,
     litellm_params=None,
     logger_fn=None,
     default_max_tokens_to_sample=None,

@@ -164,7 +164,7 @@ def completion(
         data=json.dumps(data),
         stream=optional_params["stream"] if "stream" in optional_params else False,
     )
-    if "stream" in optional_params and optional_params["stream"] == True:
+    if "stream" in optional_params and optional_params["stream"] is True:
         return clean_and_iterate_chunks(response)
     else:
         ## LOGGING

@@ -178,7 +178,7 @@ def completion(
         ## RESPONSE OBJECT
         try:
             completion_response = response.json()
-        except:
+        except Exception:
             raise NLPCloudError(message=response.text, status_code=response.status_code)
         if "error" in completion_response:
             raise NLPCloudError(

@@ -191,7 +191,7 @@ def completion(
             model_response.choices[0].message.content = (  # type: ignore
                 completion_response["generated_text"]
             )
-        except:
+        except Exception:
             raise NLPCloudError(
                 message=json.dumps(completion_response),
                 status_code=response.status_code,

@@ -14,7 +14,7 @@ import requests  # type: ignore

 import litellm
 from litellm import verbose_logger
-from litellm.types.utils import ProviderField
+from litellm.types.utils import ProviderField, StreamingChoices

 from .prompt_templates.factory import custom_prompt, prompt_factory

@@ -172,7 +172,7 @@ def _convert_image(image):

     try:
         from PIL import Image
-    except:
+    except Exception:
         raise Exception(
             "ollama image conversion failed please run `pip install Pillow`"
         )

@@ -184,7 +184,7 @@ def _convert_image(image):
         image_data = Image.open(io.BytesIO(base64.b64decode(image)))
         if image_data.format in ["JPEG", "PNG"]:
             return image
-    except:
+    except Exception:
         return orig
     jpeg_image = io.BytesIO()
     image_data.convert("RGB").save(jpeg_image, "JPEG")

@@ -195,13 +195,13 @@ def _convert_image(image):
 # ollama implementation
 def get_ollama_response(
     model_response: litellm.ModelResponse,
-    api_base="http://localhost:11434",
-    model="llama2",
-    prompt="Why is the sky blue?",
-    optional_params=None,
-    logging_obj=None,
+    model: str,
+    prompt: str,
+    optional_params: dict,
+    logging_obj: Any,
+    encoding: Any,
     acompletion: bool = False,
-    encoding=None,
+    api_base="http://localhost:11434",
 ):
     if api_base.endswith("/api/generate"):
         url = api_base

@@ -242,7 +242,7 @@ def get_ollama_response(
         },
     )
     if acompletion is True:
-        if stream == True:
+        if stream is True:
             response = ollama_async_streaming(
                 url=url,
                 data=data,

@@ -340,11 +340,16 @@ def ollama_completion_stream(url, data, logging_obj):
         # Gather all chunks and return the function call as one delta to simplify parsing
         if data.get("format", "") == "json":
             first_chunk = next(streamwrapper)
-            response_content = "".join(
-                chunk.choices[0].delta.content
-                for chunk in chain([first_chunk], streamwrapper)
-                if chunk.choices[0].delta.content
-            )
+            content_chunks = []
+            for chunk in chain([first_chunk], streamwrapper):
+                content_chunk = chunk.choices[0]
+                if (
+                    isinstance(content_chunk, StreamingChoices)
+                    and hasattr(content_chunk, "delta")
+                    and hasattr(content_chunk.delta, "content")
+                ):
+                    content_chunks.append(content_chunk.delta.content)
+            response_content = "".join(content_chunks)

             function_call = json.loads(response_content)
             delta = litellm.utils.Delta(

@@ -392,15 +397,27 @@ async def ollama_async_streaming(url, data, model_response, encoding, logging_ob
         # If format is JSON, this was a function call
         # Gather all chunks and return the function call as one delta to simplify parsing
         if data.get("format", "") == "json":
-            first_chunk = await anext(streamwrapper)
-            first_chunk_content = first_chunk.choices[0].delta.content or ""
-            response_content = first_chunk_content + "".join(
-                [
-                    chunk.choices[0].delta.content
-                    async for chunk in streamwrapper
-                    if chunk.choices[0].delta.content
-                ]
-            )
+            first_chunk = await anext(streamwrapper)  # noqa F821
+            chunk_choice = first_chunk.choices[0]
+            if (
+                isinstance(chunk_choice, StreamingChoices)
+                and hasattr(chunk_choice, "delta")
+                and hasattr(chunk_choice.delta, "content")
+            ):
+                first_chunk_content = chunk_choice.delta.content or ""
+            else:
+                first_chunk_content = ""
+
+            content_chunks = []
+            async for chunk in streamwrapper:
+                chunk_choice = chunk.choices[0]
+                if (
+                    isinstance(chunk_choice, StreamingChoices)
+                    and hasattr(chunk_choice, "delta")
+                    and hasattr(chunk_choice.delta, "content")
+                ):
+                    content_chunks.append(chunk_choice.delta.content)
+            response_content = first_chunk_content + "".join(content_chunks)
             function_call = json.loads(response_content)
             delta = litellm.utils.Delta(
                 content=None,

@@ -501,8 +518,8 @@ async def ollama_aembeddings(
     prompts: List[str],
     model_response: litellm.EmbeddingResponse,
     optional_params: dict,
-    logging_obj=None,
-    encoding=None,
+    logging_obj: Any,
+    encoding: Any,
 ):
     if api_base.endswith("/api/embed"):
         url = api_base

@@ -581,9 +598,9 @@ def ollama_embeddings(
     api_base: str,
     model: str,
     prompts: list,
-    optional_params=None,
-    logging_obj=None,
-    model_response=None,
+    optional_params: dict,
+    model_response: litellm.EmbeddingResponse,
+    logging_obj: Any,
     encoding=None,
 ):
     return asyncio.run(

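Note: the ollama streaming hunks above, and the ollama_chat hunks that follow, replace the inline generator expressions with an explicit loop that only reads `delta.content` when the chunk's first choice is a `StreamingChoices` object (hence the new `litellm.types.utils.StreamingChoices` import). A rough sketch of that guard; `collect_json_stream` and `streamwrapper` are illustrative names, not part of the diff:

```python
from litellm.types.utils import StreamingChoices


def collect_json_stream(streamwrapper) -> str:
    # Join delta.content across chunks, skipping anything that is not a
    # StreamingChoices (and therefore has no usable `delta.content`).
    content_chunks = []
    for chunk in streamwrapper:  # streamwrapper: any iterable of streamed chunks
        choice = chunk.choices[0]
        if isinstance(choice, StreamingChoices) and hasattr(choice.delta, "content"):
            if choice.delta.content:
                content_chunks.append(choice.delta.content)
    return "".join(content_chunks)
```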
@@ -4,7 +4,7 @@ import traceback
 import types
 import uuid
 from itertools import chain
-from typing import List, Optional
+from typing import Any, List, Optional
 
 import aiohttp
 import httpx
@@ -15,6 +15,7 @@ import litellm
 from litellm import verbose_logger
 from litellm.types.llms.ollama import OllamaToolCall, OllamaToolCallFunction
 from litellm.types.llms.openai import ChatCompletionAssistantToolCall
+from litellm.types.utils import StreamingChoices
 
 
 class OllamaError(Exception):
@@ -216,10 +217,10 @@ def get_ollama_response(
     model_response: litellm.ModelResponse,
     messages: list,
     optional_params: dict,
+    model: str,
+    logging_obj: Any,
     api_base="http://localhost:11434",
     api_key: Optional[str] = None,
-    model="llama2",
-    logging_obj=None,
     acompletion: bool = False,
     encoding=None,
 ):
@@ -252,10 +253,13 @@ def get_ollama_response(
             for tool in m["tool_calls"]:
                 typed_tool = ChatCompletionAssistantToolCall(**tool)  # type: ignore
                 if typed_tool["type"] == "function":
+                    arguments = {}
+                    if "arguments" in typed_tool["function"]:
+                        arguments = json.loads(typed_tool["function"]["arguments"])
                     ollama_tool_call = OllamaToolCall(
                         function=OllamaToolCallFunction(
-                            name=typed_tool["function"]["name"],
-                            arguments=json.loads(typed_tool["function"]["arguments"]),
+                            name=typed_tool["function"].get("name") or "",
+                            arguments=arguments,
                         )
                     )
                     new_tools.append(ollama_tool_call)
@@ -401,12 +405,16 @@ def ollama_completion_stream(url, api_key, data, logging_obj):
         # If format is JSON, this was a function call
         # Gather all chunks and return the function call as one delta to simplify parsing
         if data.get("format", "") == "json":
-            first_chunk = next(streamwrapper)
-            response_content = "".join(
-                chunk.choices[0].delta.content
-                for chunk in chain([first_chunk], streamwrapper)
-                if chunk.choices[0].delta.content
-            )
+            content_chunks = []
+            for chunk in streamwrapper:
+                chunk_choice = chunk.choices[0]
+                if (
+                    isinstance(chunk_choice, StreamingChoices)
+                    and hasattr(chunk_choice, "delta")
+                    and hasattr(chunk_choice.delta, "content")
+                ):
+                    content_chunks.append(chunk_choice.delta.content)
+            response_content = "".join(content_chunks)
 
             function_call = json.loads(response_content)
             delta = litellm.utils.Delta(
@@ -422,7 +430,7 @@ def ollama_completion_stream(url, api_key, data, logging_obj):
                     }
                 ],
             )
-            model_response = first_chunk
+            model_response = content_chunks[0]
             model_response.choices[0].delta = delta  # type: ignore
             model_response.choices[0].finish_reason = "tool_calls"
             yield model_response
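
The hunk above guards json.loads() behind a key check so tool calls without an "arguments" field no longer raise. A minimal sketch of that pattern; the helper name below is illustrative, not a litellm API:

import json

def parse_tool_function(tool_function: dict) -> dict:
    # Only decode "arguments" when the key is present; fall back to an empty
    # dict and an empty name, mirroring the change to get_ollama_response above.
    arguments = {}
    if "arguments" in tool_function:
        arguments = json.loads(tool_function["arguments"])
    return {"name": tool_function.get("name") or "", "arguments": arguments}

print(parse_tool_function({"name": "get_weather", "arguments": '{"city": "Paris"}'}))
print(parse_tool_function({"name": None}))
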
@@ -462,15 +470,28 @@ async def ollama_async_streaming(
         # If format is JSON, this was a function call
         # Gather all chunks and return the function call as one delta to simplify parsing
         if data.get("format", "") == "json":
-            first_chunk = await anext(streamwrapper)
-            first_chunk_content = first_chunk.choices[0].delta.content or ""
-            response_content = first_chunk_content + "".join(
-                [
-                    chunk.choices[0].delta.content
-                    async for chunk in streamwrapper
-                    if chunk.choices[0].delta.content
-                ]
-            )
+            first_chunk = await anext(streamwrapper)  # noqa F821
+            chunk_choice = first_chunk.choices[0]
+            if (
+                isinstance(chunk_choice, StreamingChoices)
+                and hasattr(chunk_choice, "delta")
+                and hasattr(chunk_choice.delta, "content")
+            ):
+                first_chunk_content = chunk_choice.delta.content or ""
+            else:
+                first_chunk_content = ""
+
+            content_chunks = []
+            async for chunk in streamwrapper:
+                chunk_choice = chunk.choices[0]
+                if (
+                    isinstance(chunk_choice, StreamingChoices)
+                    and hasattr(chunk_choice, "delta")
+                    and hasattr(chunk_choice.delta, "content")
+                ):
+                    content_chunks.append(chunk_choice.delta.content)
+            response_content = first_chunk_content + "".join(content_chunks)
 
             function_call = json.loads(response_content)
             delta = litellm.utils.Delta(
                 content=None,
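
Both streaming hunks replace the list comprehension with an explicit isinstance/hasattr guard before reading each delta. A runnable sketch of the accumulation logic, using stand-in types rather than litellm's real StreamingChoices and stream wrapper:

import asyncio
import json
from dataclasses import dataclass, field
from typing import List, Optional

@dataclass
class Delta:
    content: Optional[str] = None

@dataclass
class StreamingChoices:  # stand-in for litellm.types.utils.StreamingChoices
    delta: Delta = field(default_factory=Delta)

async def fake_stream():
    for piece in ('{"name": "get_weather", ', '"arguments": {"city": "Paris"}}'):
        yield StreamingChoices(delta=Delta(content=piece))

async def gather_function_call(stream) -> dict:
    # Collect every non-empty content delta, then parse the joined text as JSON.
    content_chunks: List[str] = []
    async for chunk_choice in stream:
        if (
            isinstance(chunk_choice, StreamingChoices)
            and hasattr(chunk_choice, "delta")
            and hasattr(chunk_choice.delta, "content")
            and chunk_choice.delta.content
        ):
            content_chunks.append(chunk_choice.delta.content)
    return json.loads("".join(content_chunks))

print(asyncio.run(gather_function_call(fake_stream())))
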
@@ -39,8 +39,8 @@ def completion(
     encoding,
     api_key,
     logging_obj,
+    optional_params: dict,
     custom_prompt_dict={},
-    optional_params=None,
     litellm_params=None,
     logger_fn=None,
     default_max_tokens_to_sample=None,
@@ -77,7 +77,7 @@ def completion(
         data=json.dumps(data),
         stream=optional_params["stream"] if "stream" in optional_params else False,
     )
-    if "stream" in optional_params and optional_params["stream"] == True:
+    if "stream" in optional_params and optional_params["stream"] is True:
         return response.iter_lines()
     else:
         ## LOGGING
@@ -91,7 +91,7 @@ def completion(
         ## RESPONSE OBJECT
         try:
             completion_response = response.json()
-        except:
+        except Exception:
             raise OobaboogaError(
                 message=response.text, status_code=response.status_code
             )
@@ -103,7 +103,7 @@ def completion(
         else:
             try:
                 model_response.choices[0].message.content = completion_response["choices"][0]["message"]["content"]  # type: ignore
-            except:
+            except Exception:
                 raise OobaboogaError(
                     message=json.dumps(completion_response),
                     status_code=response.status_code,
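
Ruff's E722 drives the `except:` → `except Exception:` changes in this hunk and the ones that follow; a bare except also swallows KeyboardInterrupt and SystemExit. A small sketch of the preferred form, with illustrative names that are not litellm APIs:

import json

def parse_json_response(text: str) -> dict:
    try:
        return json.loads(text)
    except Exception:  # narrow enough for ruff, still lets SystemExit propagate
        raise ValueError(f"Unable to parse response. Got={text!r}")

print(parse_json_response('{"ok": true}'))
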
@@ -96,13 +96,13 @@ def completion(
     api_key,
     encoding,
     logging_obj,
-    optional_params=None,
+    optional_params: dict,
     litellm_params=None,
     logger_fn=None,
 ):
     try:
         import google.generativeai as palm  # type: ignore
-    except:
+    except Exception:
         raise Exception(
             "Importing google.generativeai failed, please run 'pip install -q google-generativeai"
         )
@@ -167,14 +167,14 @@ def completion(
             choice_obj = Choices(index=idx + 1, message=message_obj)
             choices_list.append(choice_obj)
         model_response.choices = choices_list  # type: ignore
-    except Exception as e:
+    except Exception:
         raise PalmError(
             message=traceback.format_exc(), status_code=response.status_code
         )
 
     try:
         completion_response = model_response["choices"][0]["message"].get("content")
-    except:
+    except Exception:
         raise PalmError(
             status_code=400,
             message=f"No response received. Original response - {response}",
@@ -98,7 +98,7 @@ def completion(
     print_verbose: Callable,
     encoding,
     logging_obj,
-    optional_params=None,
+    optional_params: dict,
     stream=False,
     litellm_params=None,
     logger_fn=None,
@@ -123,6 +123,7 @@ def completion(
     else:
         prompt = prompt_factory(model=model, messages=messages)
 
+    output_text: Optional[str] = None
     if api_base:
         ## LOGGING
         logging_obj.pre_call(
@@ -157,7 +158,7 @@ def completion(
             import torch
             from petals import AutoDistributedModelForCausalLM  # type: ignore
             from transformers import AutoTokenizer
-        except:
+        except Exception:
             raise Exception(
                 "Importing torch, transformers, petals failed\nTry pip installing petals \npip install git+https://github.com/bigscience-workshop/petals"
            )
@@ -192,7 +193,7 @@ def completion(
         ## RESPONSE OBJECT
         output_text = tokenizer.decode(outputs[0])
 
-    if len(output_text) > 0:
+    if output_text is not None and len(output_text) > 0:
         model_response.choices[0].message.content = output_text  # type: ignore
 
     prompt_tokens = len(encoding.encode(prompt))
@@ -265,7 +265,7 @@ class PredibaseChatCompletion(BaseLLM):
         ## RESPONSE OBJECT
         try:
             completion_response = response.json()
-        except:
+        except Exception:
             raise PredibaseError(message=response.text, status_code=422)
         if "error" in completion_response:
             raise PredibaseError(
@@ -348,7 +348,7 @@ class PredibaseChatCompletion(BaseLLM):
                         model_response["choices"][0]["message"].get("content", "")
                     )
                 )  ##[TODO] use a model-specific tokenizer
-            except:
+            except Exception:
                 # this should remain non blocking we should not block a response returning if calculating usage fails
                 pass
         else:
@@ -5,7 +5,7 @@ import traceback
 import uuid
 import xml.etree.ElementTree as ET
 from enum import Enum
-from typing import Any, List, Mapping, MutableMapping, Optional, Sequence, Tuple
+from typing import Any, List, Mapping, MutableMapping, Optional, Sequence, Tuple, cast
 
 from jinja2 import BaseLoader, Template, exceptions, meta
 from jinja2.sandbox import ImmutableSandboxedEnvironment
@@ -26,11 +26,14 @@ from litellm.types.completion import (
 )
 from litellm.types.llms.anthropic import *
 from litellm.types.llms.bedrock import MessageBlock as BedrockMessageBlock
+from litellm.types.llms.ollama import OllamaVisionModelObject
 from litellm.types.llms.openai import (
     AllMessageValues,
     ChatCompletionAssistantMessage,
     ChatCompletionAssistantToolCall,
     ChatCompletionFunctionMessage,
+    ChatCompletionImageObject,
+    ChatCompletionTextObject,
     ChatCompletionToolCallFunctionChunk,
     ChatCompletionToolMessage,
     ChatCompletionUserMessage,
@@ -164,7 +167,9 @@ def convert_to_ollama_image(openai_image_url: str):
 
 def ollama_pt(
     model, messages
-):  # https://github.com/ollama/ollama/blob/af4cf55884ac54b9e637cd71dadfe9b7a5685877/docs/modelfile.md#template
+) -> Union[
+    str, OllamaVisionModelObject
+]:  # https://github.com/ollama/ollama/blob/af4cf55884ac54b9e637cd71dadfe9b7a5685877/docs/modelfile.md#template
     if "instruct" in model:
         prompt = custom_prompt(
             role_dict={
@@ -438,7 +443,7 @@ def hf_chat_template(model: str, messages: list, chat_template: Optional[Any] =
     def _is_system_in_template():
         try:
             # Try rendering the template with a system message
-            response = template.render(
+            template.render(
                 messages=[{"role": "system", "content": "test"}],
                 eos_token="<eos>",
                 bos_token="<bos>",
@@ -446,10 +451,11 @@ def hf_chat_template(model: str, messages: list, chat_template: Optional[Any] =
             return True
 
         # This will be raised if Jinja attempts to render the system message and it can't
-        except:
+        except Exception:
             return False
 
     try:
+        rendered_text = ""
         # Render the template with the provided values
         if _is_system_in_template():
             rendered_text = template.render(
@@ -460,8 +466,8 @@ def hf_chat_template(model: str, messages: list, chat_template: Optional[Any] =
             )
         else:
             # treat a system message as a user message, if system not in template
+            reformatted_messages = []
            try:
-                reformatted_messages = []
                 for message in messages:
                     if message["role"] == "system":
                         reformatted_messages.append(
@@ -556,30 +562,31 @@ def get_model_info(token, model):
                 return None, None
         else:
             return None, None
-    except Exception as e:  # safely fail a prompt template request
+    except Exception:  # safely fail a prompt template request
         return None, None
 
 
-def format_prompt_togetherai(messages, prompt_format, chat_template):
-    if prompt_format is None:
-        return default_pt(messages)
-
-    human_prompt, assistant_prompt = prompt_format.split("{prompt}")
-
-    if chat_template is not None:
-        prompt = hf_chat_template(
-            model=None, messages=messages, chat_template=chat_template
-        )
-    elif prompt_format is not None:
-        prompt = custom_prompt(
-            role_dict={},
-            messages=messages,
-            initial_prompt_value=human_prompt,
-            final_prompt_value=assistant_prompt,
-        )
-    else:
-        prompt = default_pt(messages)
-    return prompt
+## OLD TOGETHER AI FLOW
+# def format_prompt_togetherai(messages, prompt_format, chat_template):
+#     if prompt_format is None:
+#         return default_pt(messages)
+
+#     human_prompt, assistant_prompt = prompt_format.split("{prompt}")
+
+#     if chat_template is not None:
+#         prompt = hf_chat_template(
+#             model=None, messages=messages, chat_template=chat_template
+#         )
+#     elif prompt_format is not None:
+#         prompt = custom_prompt(
+#             role_dict={},
+#             messages=messages,
+#             initial_prompt_value=human_prompt,
+#             final_prompt_value=assistant_prompt,
+#         )
+#     else:
+#         prompt = default_pt(messages)
+#     return prompt
 
 
 ### IBM Granite
@@ -1063,7 +1070,7 @@ def convert_to_gemini_tool_call_invoke(
         else:  # don't silently drop params. Make it clear to user what's happening.
             raise Exception(
                 "function_call missing. Received tool call with 'type': 'function'. No function call in argument - {}".format(
-                    tool
+                    message
                 )
             )
     return _parts_list
@@ -1216,12 +1223,14 @@ def convert_function_to_anthropic_tool_invoke(
     function_call: Union[dict, ChatCompletionToolCallFunctionChunk],
 ) -> List[AnthropicMessagesToolUseParam]:
     try:
+        _name = get_attribute_or_key(function_call, "name") or ""
+        _arguments = get_attribute_or_key(function_call, "arguments")
         anthropic_tool_invoke = [
             AnthropicMessagesToolUseParam(
                 type="tool_use",
                 id=str(uuid.uuid4()),
-                name=get_attribute_or_key(function_call, "name"),
-                input=json.loads(get_attribute_or_key(function_call, "arguments")),
+                name=_name,
+                input=json.loads(_arguments) if _arguments else {},
             )
         ]
         return anthropic_tool_invoke
@@ -1349,8 +1358,9 @@ def anthropic_messages_pt(
             ):
                 for m in user_message_types_block["content"]:
                     if m.get("type", "") == "image_url":
+                        m = cast(ChatCompletionImageObject, m)
                         image_chunk = convert_to_anthropic_image_obj(
-                            m["image_url"]["url"]
+                            openai_image_url=m["image_url"]["url"]  # type: ignore
                         )
 
                         _anthropic_content_element = AnthropicMessagesImageParam(
@@ -1362,21 +1372,31 @@ def anthropic_messages_pt(
                             ),
                         )
 
-                        anthropic_content_element = add_cache_control_to_content(
+                        _content_element = add_cache_control_to_content(
                             anthropic_content_element=_anthropic_content_element,
-                            orignal_content_element=m,
+                            orignal_content_element=dict(m),
                         )
-                        user_content.append(anthropic_content_element)
+
+                        if "cache_control" in _content_element:
+                            _anthropic_content_element["cache_control"] = (
+                                _content_element["cache_control"]
+                            )
+                        user_content.append(_anthropic_content_element)
                     elif m.get("type", "") == "text":
-                        _anthropic_text_content_element = {
-                            "type": "text",
-                            "text": m["text"],
-                        }
-                        anthropic_content_element = add_cache_control_to_content(
-                            anthropic_content_element=_anthropic_text_content_element,
-                            orignal_content_element=m,
-                        )
-                        user_content.append(anthropic_content_element)
+                        m = cast(ChatCompletionTextObject, m)
+                        _anthropic_text_content_element = AnthropicMessagesTextParam(
+                            type="text",
+                            text=m["text"],
+                        )
+                        _content_element = add_cache_control_to_content(
+                            anthropic_content_element=_anthropic_text_content_element,
+                            orignal_content_element=dict(m),
+                        )
+                        _content_element = cast(
+                            AnthropicMessagesTextParam, _content_element
+                        )
+
+                        user_content.append(_content_element)
             elif (
                 user_message_types_block["role"] == "tool"
                 or user_message_types_block["role"] == "function"
@@ -1390,12 +1410,17 @@ def anthropic_messages_pt(
                     "type": "text",
                     "text": user_message_types_block["content"],
                 }
-                anthropic_content_element = add_cache_control_to_content(
+                _content_element = add_cache_control_to_content(
                     anthropic_content_element=_anthropic_content_text_element,
-                    orignal_content_element=user_message_types_block,
+                    orignal_content_element=dict(user_message_types_block),
                 )
 
-                user_content.append(anthropic_content_element)
+                if "cache_control" in _content_element:
+                    _anthropic_content_text_element["cache_control"] = _content_element[
+                        "cache_control"
+                    ]
+
+                user_content.append(_anthropic_content_text_element)
 
         msg_i += 1
 
@@ -1417,11 +1442,14 @@ def anthropic_messages_pt(
                         anthropic_message = AnthropicMessagesTextParam(
                             type="text", text=m.get("text")
                         )
-                        anthropic_message = add_cache_control_to_content(
+                        _cached_message = add_cache_control_to_content(
                             anthropic_content_element=anthropic_message,
-                            orignal_content_element=m,
+                            orignal_content_element=dict(m),
+                        )
+
+                        assistant_content.append(
+                            cast(AnthropicMessagesTextParam, _cached_message)
                         )
-                        assistant_content.append(anthropic_message)
             elif (
                 "content" in assistant_content_block
                 and isinstance(assistant_content_block["content"], str)
@@ -1430,16 +1458,22 @@ def anthropic_messages_pt(
                 ]  # don't pass empty text blocks. anthropic api raises errors.
             ):
 
-                _anthropic_text_content_element = {
-                    "type": "text",
-                    "text": assistant_content_block["content"],
-                }
-
-                anthropic_content_element = add_cache_control_to_content(
-                    anthropic_content_element=_anthropic_text_content_element,
-                    orignal_content_element=assistant_content_block,
-                )
-                assistant_content.append(anthropic_content_element)
+                _anthropic_text_content_element = AnthropicMessagesTextParam(
+                    type="text",
+                    text=assistant_content_block["content"],
+                )
+
+                _content_element = add_cache_control_to_content(
+                    anthropic_content_element=_anthropic_text_content_element,
+                    orignal_content_element=dict(assistant_content_block),
+                )
+
+                if "cache_control" in _content_element:
+                    _anthropic_text_content_element["cache_control"] = _content_element[
+                        "cache_control"
+                    ]
+
+                assistant_content.append(_anthropic_text_content_element)
 
             assistant_tool_calls = assistant_content_block.get("tool_calls")
             if (
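
The anthropic_messages_pt hunks above stop appending the dict returned by add_cache_control_to_content() directly and instead copy its cache_control key back onto the typed content element. A hedged sketch of that flow with a stand-in helper; the real helper and typed params live in litellm's prompt factory and types:

from typing import Any, Dict

def add_cache_control_to_content(
    anthropic_content_element: Dict[str, Any], orignal_content_element: Dict[str, Any]
) -> Dict[str, Any]:
    # Stand-in for the real helper: returns a copy that may carry cache_control.
    out = dict(anthropic_content_element)
    if "cache_control" in orignal_content_element:
        out["cache_control"] = orignal_content_element["cache_control"]
    return out

_text_element = {"type": "text", "text": "hello"}
_content_element = add_cache_control_to_content(
    anthropic_content_element=_text_element,
    orignal_content_element={"cache_control": {"type": "ephemeral"}},
)
if "cache_control" in _content_element:
    # mutate the typed element that actually gets appended, as the new code does
    _text_element["cache_control"] = _content_element["cache_control"]

print(_text_element)
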
@@ -1566,30 +1600,6 @@ def get_system_prompt(messages):
     return system_prompt, messages
 
 
-def convert_to_documents(
-    observations: Any,
-) -> List[MutableMapping]:
-    """Converts observations into a 'document' dict"""
-    documents: List[MutableMapping] = []
-    if isinstance(observations, str):
-        # strings are turned into a key/value pair and a key of 'output' is added.
-        observations = [{"output": observations}]
-    elif isinstance(observations, Mapping):
-        # single mappings are transformed into a list to simplify the rest of the code.
-        observations = [observations]
-    elif not isinstance(observations, Sequence):
-        # all other types are turned into a key/value pair within a list
-        observations = [{"output": observations}]
-
-    for doc in observations:
-        if not isinstance(doc, Mapping):
-            # types that aren't Mapping are turned into a key/value pair.
-            doc = {"output": doc}
-        documents.append(doc)
-
-    return documents
-
-
 from litellm.types.llms.cohere import (
     CallObject,
     ChatHistory,
@@ -1943,7 +1953,7 @@ def amazon_titan_pt(
 def _load_image_from_url(image_url):
     try:
         from PIL import Image
-    except:
+    except Exception:
         raise Exception("image conversion failed please run `pip install Pillow`")
     from io import BytesIO
 
@@ -2008,7 +2018,7 @@ def _gemini_vision_convert_messages(messages: list):
         else:
             try:
                 from PIL import Image
-            except:
+            except Exception:
                 raise Exception(
                     "gemini image conversion failed please run `pip install Pillow`"
                 )
@@ -2056,7 +2066,7 @@ def gemini_text_image_pt(messages: list):
     """
     try:
         import google.generativeai as genai  # type: ignore
-    except:
+    except Exception:
         raise Exception(
             "Importing google.generativeai failed, please run 'pip install -q google-generativeai"
         )
@@ -2331,7 +2341,7 @@ def _convert_to_bedrock_tool_call_result(
         for content in content_list:
             if content["type"] == "text":
                 content_str += content["text"]
-    name = message.get("name", "")
+    message.get("name", "")
     id = str(message.get("tool_call_id", str(uuid.uuid4())))
 
     tool_result_content_block = BedrockToolResultContentBlock(text=content_str)
@@ -2575,7 +2585,7 @@ def function_call_prompt(messages: list, functions: list):
             message["content"] += f""" {function_prompt}"""
             function_added_to_prompt = True
 
-    if function_added_to_prompt == False:
+    if function_added_to_prompt is False:
         messages.append({"role": "system", "content": f"""{function_prompt}"""})
 
     return messages
@@ -2692,11 +2702,6 @@ def prompt_factory(
             )
         elif custom_llm_provider == "anthropic_xml":
             return anthropic_messages_pt_xml(messages=messages)
-        elif custom_llm_provider == "together_ai":
-            prompt_format, chat_template = get_model_info(token=api_key, model=model)
-            return format_prompt_togetherai(
-                messages=messages, prompt_format=prompt_format, chat_template=chat_template
-            )
         elif custom_llm_provider == "gemini":
             if (
                 model == "gemini-pro-vision"
@@ -2810,7 +2815,7 @@ def prompt_factory(
                 )
             else:
                 return hf_chat_template(original_model_name, messages)
-    except Exception as e:
+    except Exception:
         return default_pt(
             messages=messages
         )  # default that covers Bloom, T-5, any non-chat tuned model (e.g. base Llama2)
@@ -61,7 +61,7 @@ async def async_convert_url_to_base64(url: str) -> str:
         try:
             response = await client.get(url, follow_redirects=True)
             return _process_image_response(response, url)
-        except:
+        except Exception:
             pass
     raise Exception(
         f"Error: Unable to fetch image from URL after 3 attempts. url={url}"
 
@@ -297,7 +297,7 @@ def handle_prediction_response_streaming(prediction_url, api_token, print_verbos
             if "output" in response_data:
                 try:
                     output_string = "".join(response_data["output"])
-                except Exception as e:
+                except Exception:
                     raise ReplicateError(
                         status_code=422,
                         message="Unable to parse response. Got={}".format(
@@ -344,7 +344,7 @@ async def async_handle_prediction_response_streaming(
             if "output" in response_data:
                 try:
                     output_string = "".join(response_data["output"])
-                except Exception as e:
+                except Exception:
                     raise ReplicateError(
                         status_code=422,
                         message="Unable to parse response. Got={}".format(
@@ -479,7 +479,7 @@ def completion(
     else:
         input_data = {"prompt": prompt, **optional_params}
 
-    if acompletion is not None and acompletion == True:
+    if acompletion is not None and acompletion is True:
         return async_completion(
             model_response=model_response,
             model=model,
@@ -513,7 +513,7 @@ def completion(
     print_verbose(prediction_url)
 
     # Handle the prediction response (streaming or non-streaming)
-    if "stream" in optional_params and optional_params["stream"] == True:
+    if "stream" in optional_params and optional_params["stream"] is True:
         print_verbose("streaming request")
         _response = handle_prediction_response_streaming(
             prediction_url, api_key, print_verbose
@@ -571,7 +571,7 @@ async def async_completion(
         http_handler=http_handler,
     )
 
-    if "stream" in optional_params and optional_params["stream"] == True:
+    if "stream" in optional_params and optional_params["stream"] is True:
         _response = async_handle_prediction_response_streaming(
             prediction_url, api_key, print_verbose
         )
@@ -8,7 +8,7 @@ import types
 from copy import deepcopy
 from enum import Enum
 from functools import partial
-from typing import Any, AsyncIterator, Callable, Iterator, List, Optional, Union
+from typing import Any, AsyncIterator, Callable, Dict, Iterator, List, Optional, Union
 
 import httpx  # type: ignore
 import requests  # type: ignore
@@ -112,7 +112,7 @@ class SagemakerLLM(BaseAWSLLM):
     ):
         try:
             from botocore.credentials import Credentials
-        except ImportError as e:
+        except ImportError:
             raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
         ## CREDENTIALS ##
         # pop aws_secret_access_key, aws_access_key_id, aws_session_token, aws_region_name from kwargs, since completion calls fail with them
@@ -123,7 +123,7 @@ class SagemakerLLM(BaseAWSLLM):
         aws_role_name = optional_params.pop("aws_role_name", None)
         aws_session_name = optional_params.pop("aws_session_name", None)
         aws_profile_name = optional_params.pop("aws_profile_name", None)
-        aws_bedrock_runtime_endpoint = optional_params.pop(
+        optional_params.pop(
             "aws_bedrock_runtime_endpoint", None
         )  # https://bedrock-runtime.{region_name}.amazonaws.com
         aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)
@@ -175,7 +175,7 @@ class SagemakerLLM(BaseAWSLLM):
             from botocore.auth import SigV4Auth
             from botocore.awsrequest import AWSRequest
             from botocore.credentials import Credentials
-        except ImportError as e:
+        except ImportError:
             raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
 
         sigv4 = SigV4Auth(credentials, "sagemaker", aws_region_name)
@@ -244,7 +244,7 @@ class SagemakerLLM(BaseAWSLLM):
             hf_model_name = (
                 hf_model_name or model
             )  # pass in hf model name for pulling it's prompt template - (e.g. `hf_model_name="meta-llama/Llama-2-7b-chat-hf` applies the llama2 chat template to the prompt)
-            prompt = prompt_factory(model=hf_model_name, messages=messages)
+            prompt: str = prompt_factory(model=hf_model_name, messages=messages)  # type: ignore
 
         return prompt
 
@@ -256,10 +256,10 @@ class SagemakerLLM(BaseAWSLLM):
         print_verbose: Callable,
         encoding,
         logging_obj,
+        optional_params: dict,
         timeout: Optional[Union[float, httpx.Timeout]] = None,
         custom_prompt_dict={},
         hf_model_name=None,
-        optional_params=None,
         litellm_params=None,
         logger_fn=None,
         acompletion: bool = False,
@@ -277,7 +277,7 @@ class SagemakerLLM(BaseAWSLLM):
 
             openai_like_chat_completions = DatabricksChatCompletion()
             inference_params["stream"] = True if stream is True else False
-            _data = {
+            _data: Dict[str, Any] = {
                 "model": model,
                 "messages": messages,
                 **inference_params,
@@ -310,7 +310,7 @@ class SagemakerLLM(BaseAWSLLM):
                 logger_fn=logger_fn,
                 timeout=timeout,
                 encoding=encoding,
-                headers=prepared_request.headers,
+                headers=prepared_request.headers,  # type: ignore
                 custom_endpoint=True,
                 custom_llm_provider="sagemaker_chat",
                 streaming_decoder=custom_stream_decoder,  # type: ignore
@@ -474,7 +474,7 @@ class SagemakerLLM(BaseAWSLLM):
             try:
                 sync_response = sync_handler.post(
                     url=prepared_request.url,
-                    headers=prepared_request.headers,
+                    headers=prepared_request.headers,  # type: ignore
                     json=_data,
                     timeout=timeout,
                 )
@@ -559,7 +559,7 @@ class SagemakerLLM(BaseAWSLLM):
         self,
         api_base: str,
         headers: dict,
-        data: str,
+        data: dict,
         logging_obj,
         client=None,
     ):
@@ -598,7 +598,7 @@ class SagemakerLLM(BaseAWSLLM):
         except httpx.HTTPStatusError as err:
             error_code = err.response.status_code
             raise SagemakerError(status_code=error_code, message=err.response.text)
-        except httpx.TimeoutException as e:
+        except httpx.TimeoutException:
             raise SagemakerError(status_code=408, message="Timeout error occurred.")
         except Exception as e:
             raise SagemakerError(status_code=500, message=str(e))
@@ -638,7 +638,7 @@ class SagemakerLLM(BaseAWSLLM):
             make_call=partial(
                 self.make_async_call,
                 api_base=prepared_request.url,
-                headers=prepared_request.headers,
+                headers=prepared_request.headers,  # type: ignore
                 data=data,
                 logging_obj=logging_obj,
             ),
@@ -716,7 +716,7 @@ class SagemakerLLM(BaseAWSLLM):
             try:
                 response = await async_handler.post(
                     url=prepared_request.url,
-                    headers=prepared_request.headers,
+                    headers=prepared_request.headers,  # type: ignore
                     json=data,
                     timeout=timeout,
                 )
@@ -794,8 +794,8 @@ class SagemakerLLM(BaseAWSLLM):
         print_verbose: Callable,
         encoding,
         logging_obj,
+        optional_params: dict,
         custom_prompt_dict={},
-        optional_params=None,
         litellm_params=None,
         logger_fn=None,
     ):
@@ -1032,7 +1032,7 @@ class AWSEventStreamDecoder:
                     yield self._chunk_parser_messages_api(chunk_data=_data)
                 else:
                     yield self._chunk_parser(chunk_data=_data)
-            except json.JSONDecodeError as e:
+            except json.JSONDecodeError:
                 # Handle or log any unparseable data at the end
                 verbose_logger.error(
                     f"Warning: Unparseable JSON data remained: {accumulated_json}"
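
Several SageMaker changes above fix ruff F841 (a local bound but never used) by keeping only the side effect of the call. A minimal illustration with a made-up params dict:

optional_params = {"aws_bedrock_runtime_endpoint": "https://example.invalid", "temperature": 0.2}

# before: aws_bedrock_runtime_endpoint = optional_params.pop("aws_bedrock_runtime_endpoint", None)
optional_params.pop("aws_bedrock_runtime_endpoint", None)  # side effect only, no unused binding

print(optional_params)  # {'temperature': 0.2}
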
@@ -17,6 +17,7 @@ import requests  # type: ignore
 import litellm
 from litellm import verbose_logger
 from litellm.litellm_core_utils.core_helpers import map_finish_reason
+from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
 from litellm.types.llms.databricks import GenericStreamingChunk
 from litellm.utils import (
@@ -157,7 +158,7 @@ class MistralTextCompletionConfig:
                 optional_params["top_p"] = value
             if param == "max_tokens" or param == "max_completion_tokens":
                 optional_params["max_tokens"] = value
-            if param == "stream" and value == True:
+            if param == "stream" and value is True:
                 optional_params["stream"] = value
             if param == "stop":
                 optional_params["stop"] = value
@@ -249,7 +250,7 @@ class CodestralTextCompletion(BaseLLM):
         response: Union[requests.Response, httpx.Response],
         model_response: TextCompletionResponse,
         stream: bool,
-        logging_obj: litellm.litellm_core_utils.litellm_logging.Logging,
+        logging_obj: LiteLLMLogging,
         optional_params: dict,
         api_key: str,
         data: Union[dict, str],
@@ -273,7 +274,7 @@ class CodestralTextCompletion(BaseLLM):
         )
         try:
             completion_response = response.json()
-        except:
+        except Exception:
             raise TextCompletionCodestralError(message=response.text, status_code=422)
 
         _original_choices = completion_response.get("choices", [])
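
The `== True` → `is True` edits here and in the Vertex AI hunks below follow ruff E712: for booleans, an identity check is unambiguous and avoids surprising truthiness coercion. Sketch:

optional_params = {"stream": True}

if "stream" in optional_params and optional_params["stream"] is True:
    print("streaming request")
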
@@ -176,7 +176,7 @@ class VertexAIConfig:
             if param == "top_p":
                 optional_params["top_p"] = value
             if (
-                param == "stream" and value == True
+                param == "stream" and value is True
             ):  # sending stream = False, can cause it to get passed unchecked and raise issues
                 optional_params["stream"] = value
             if param == "n":
@@ -1313,7 +1313,6 @@ class ModelResponseIterator:
 
         text = ""
         tool_use: Optional[ChatCompletionToolCallChunk] = None
-        is_finished = False
         finish_reason = ""
         usage: Optional[ChatCompletionUsageBlock] = None
         _candidates: Optional[List[Candidates]] = processed_chunk.get("candidates")
 
@@ -268,7 +268,7 @@ def completion(
 ):
     try:
         import vertexai
-    except:
+    except Exception:
         raise VertexAIError(
             status_code=400,
             message="""vertexai import failed please run `pip install -U google-cloud-aiplatform "anthropic[vertex]"`""",
@@ -5,7 +5,7 @@ import time
 import types
 import uuid
 from enum import Enum
-from typing import Any, Callable, List, Literal, Optional, Union
+from typing import Any, Callable, List, Literal, Optional, Union, cast
 
 import httpx  # type: ignore
 import requests  # type: ignore
@@ -25,7 +25,12 @@ from litellm.types.files import (
     is_gemini_1_5_accepted_file_type,
     is_video_file_type,
 )
-from litellm.types.llms.openai import AllMessageValues, ChatCompletionAssistantMessage
+from litellm.types.llms.openai import (
+    AllMessageValues,
+    ChatCompletionAssistantMessage,
+    ChatCompletionImageObject,
+    ChatCompletionTextObject,
+)
 from litellm.types.llms.vertex_ai import *
 from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
 
@@ -150,30 +155,34 @@ def _gemini_convert_messages_with_history(
         while (
             msg_i < len(messages) and messages[msg_i]["role"] in user_message_types
         ):
-            if messages[msg_i]["content"] is not None and isinstance(
-                messages[msg_i]["content"], list
-            ):
+            _message_content = messages[msg_i].get("content")
+            if _message_content is not None and isinstance(_message_content, list):
                 _parts: List[PartType] = []
-                for element in messages[msg_i]["content"]:  # type: ignore
-                    if isinstance(element, dict):
-                        if element["type"] == "text" and len(element["text"]) > 0:  # type: ignore
-                            _part = PartType(text=element["text"])  # type: ignore
-                            _parts.append(_part)
-                        elif element["type"] == "image_url":
-                            img_element: ChatCompletionImageObject = element  # type: ignore
-                            if isinstance(img_element["image_url"], dict):
-                                image_url = img_element["image_url"]["url"]
-                            else:
-                                image_url = img_element["image_url"]
-                            _part = _process_gemini_image(image_url=image_url)
-                            _parts.append(_part)  # type: ignore
+                for element in _message_content:
+                    if (
+                        element["type"] == "text"
+                        and "text" in element
+                        and len(element["text"]) > 0
+                    ):
+                        element = cast(ChatCompletionTextObject, element)
+                        _part = PartType(text=element["text"])
+                        _parts.append(_part)
+                    elif element["type"] == "image_url":
+                        element = cast(ChatCompletionImageObject, element)
+                        img_element = element
+                        if isinstance(img_element["image_url"], dict):
+                            image_url = img_element["image_url"]["url"]
+                        else:
+                            image_url = img_element["image_url"]
+                        _part = _process_gemini_image(image_url=image_url)
+                        _parts.append(_part)
                 user_content.extend(_parts)
             elif (
-                messages[msg_i]["content"] is not None
-                and isinstance(messages[msg_i]["content"], str)
-                and len(messages[msg_i]["content"]) > 0  # type: ignore
+                _message_content is not None
+                and isinstance(_message_content, str)
+                and len(_message_content) > 0
             ):
-                _part = PartType(text=messages[msg_i]["content"])  # type: ignore
+                _part = PartType(text=_message_content)
                 user_content.append(_part)
 
             msg_i += 1
@@ -201,22 +210,21 @@ def _gemini_convert_messages_with_history(
         else:
             msg_dict = messages[msg_i]  # type: ignore
             assistant_msg = ChatCompletionAssistantMessage(**msg_dict)  # type: ignore
-            if assistant_msg.get("content", None) is not None and isinstance(
-                assistant_msg["content"], list
-            ):
+            _message_content = assistant_msg.get("content", None)
+            if _message_content is not None and isinstance(_message_content, list):
                 _parts = []
-                for element in assistant_msg["content"]:
+                for element in _message_content:
                     if isinstance(element, dict):
                         if element["type"] == "text":
-                            _part = PartType(text=element["text"])  # type: ignore
+                            _part = PartType(text=element["text"])
                             _parts.append(_part)
                 assistant_content.extend(_parts)
             elif (
-                assistant_msg.get("content", None) is not None
-                and isinstance(assistant_msg["content"], str)
-                and assistant_msg["content"]
+                _message_content is not None
+                and isinstance(_message_content, str)
+                and _message_content
             ):
-                assistant_text = assistant_msg["content"]  # either string or none
+                assistant_text = _message_content  # either string or none
                 assistant_content.append(PartType(text=assistant_text))  # type: ignore
 
             ## HANDLE ASSISTANT FUNCTION CALL
@@ -256,7 +264,9 @@ def _gemini_convert_messages_with_history(
         raise e
 
 
-def _get_client_cache_key(model: str, vertex_project: str, vertex_location: str):
+def _get_client_cache_key(
+    model: str, vertex_project: Optional[str], vertex_location: Optional[str]
+):
     _cache_key = f"{model}-{vertex_project}-{vertex_location}"
     return _cache_key
 
@@ -294,7 +304,7 @@ def completion(
     """
     try:
         import vertexai
-    except:
+    except Exception:
         raise VertexAIError(
             status_code=400,
             message="vertexai import failed please run `pip install google-cloud-aiplatform`. This is required for the 'vertex_ai/' route on LiteLLM",
@@ -339,6 +349,8 @@ def completion(
     _vertex_llm_model_object = _get_client_from_cache(client_cache_key=_cache_key)
 
     if _vertex_llm_model_object is None:
+        from google.auth.credentials import Credentials
+
         if vertex_credentials is not None and isinstance(vertex_credentials, str):
             import google.oauth2.service_account
 
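
The Vertex AI hunks lean on typing.cast() to satisfy pyright without changing runtime behaviour; cast() is a no-op at runtime and only informs the type checker. A small sketch with generic types rather than litellm's:

from typing import List, Union, cast

def first_text(element: Union[str, List[str]]) -> str:
    if isinstance(element, list):
        items = cast(List[str], element)  # informs the type checker only
        return items[0]
    return element

print(first_text(["hello", "world"]))
print(first_text("solo"))
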
|
@ -356,7 +368,9 @@ def completion(
|
||||||
f"VERTEX AI: creds={creds}; google application credentials: {os.getenv('GOOGLE_APPLICATION_CREDENTIALS')}"
|
f"VERTEX AI: creds={creds}; google application credentials: {os.getenv('GOOGLE_APPLICATION_CREDENTIALS')}"
|
||||||
)
|
)
|
||||||
vertexai.init(
|
vertexai.init(
|
||||||
project=vertex_project, location=vertex_location, credentials=creds
|
project=vertex_project,
|
||||||
|
location=vertex_location,
|
||||||
|
credentials=cast(Credentials, creds),
|
||||||
)
|
)
|
||||||
|
|
||||||
## Load Config
|
## Load Config
|
||||||
|
@ -391,7 +405,6 @@ def completion(
|
||||||
|
|
||||||
request_str = ""
|
request_str = ""
|
||||||
response_obj = None
|
response_obj = None
|
||||||
async_client = None
|
|
||||||
instances = None
|
instances = None
|
||||||
client_options = {
|
client_options = {
|
||||||
"api_endpoint": f"{vertex_location}-aiplatform.googleapis.com"
|
"api_endpoint": f"{vertex_location}-aiplatform.googleapis.com"
|
||||||
|
@ -400,7 +413,7 @@ def completion(
|
||||||
model in litellm.vertex_language_models
|
model in litellm.vertex_language_models
|
||||||
or model in litellm.vertex_vision_models
|
or model in litellm.vertex_vision_models
|
||||||
):
|
):
|
||||||
llm_model = _vertex_llm_model_object or GenerativeModel(model)
|
llm_model: Any = _vertex_llm_model_object or GenerativeModel(model)
|
||||||
mode = "vision"
|
mode = "vision"
|
||||||
request_str += f"llm_model = GenerativeModel({model})\n"
|
request_str += f"llm_model = GenerativeModel({model})\n"
|
||||||
elif model in litellm.vertex_chat_models:
|
elif model in litellm.vertex_chat_models:
|
||||||
|
@ -459,7 +472,6 @@ def completion(
|
||||||
"model_response": model_response,
|
"model_response": model_response,
|
||||||
"encoding": encoding,
|
"encoding": encoding,
|
||||||
"messages": messages,
|
"messages": messages,
|
||||||
"request_str": request_str,
|
|
||||||
"print_verbose": print_verbose,
|
"print_verbose": print_verbose,
|
||||||
"client_options": client_options,
|
"client_options": client_options,
|
||||||
"instances": instances,
|
"instances": instances,
|
||||||
|
@ -474,6 +486,7 @@ def completion(
|
||||||
|
|
||||||
return async_completion(**data)
|
return async_completion(**data)
|
||||||
|
|
||||||
|
completion_response = None
|
||||||
if mode == "vision":
|
if mode == "vision":
|
||||||
print_verbose("\nMaking VertexAI Gemini Pro / Pro Vision Call")
|
print_verbose("\nMaking VertexAI Gemini Pro / Pro Vision Call")
|
||||||
print_verbose(f"\nProcessing input messages = {messages}")
|
print_verbose(f"\nProcessing input messages = {messages}")
|
||||||
|
@ -529,7 +542,7 @@ def completion(
|
||||||
# Check if it's a RepeatedComposite instance
|
# Check if it's a RepeatedComposite instance
|
||||||
for key, val in function_call.args.items():
|
for key, val in function_call.args.items():
|
||||||
if isinstance(
|
if isinstance(
|
||||||
val, proto.marshal.collections.repeated.RepeatedComposite
|
val, proto.marshal.collections.repeated.RepeatedComposite # type: ignore
|
||||||
):
|
):
|
||||||
# If so, convert to list
|
# If so, convert to list
|
||||||
args_dict[key] = [v for v in val]
|
args_dict[key] = [v for v in val]
|
||||||
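The `# type: ignore` added above silences pyright on the proto-plus `RepeatedComposite` check. As a rough sketch of why that conversion exists (illustrative helper name, assuming the `proto` / proto-plus package that ships with google-cloud-aiplatform), the tool-call arguments have to be unwrapped into plain Python containers before they can be JSON-serialized:

import proto  # proto-plus, pulled in by google-cloud-aiplatform


def proto_args_to_dict(args) -> dict:
    # Unwrap proto-plus repeated containers into plain lists so that
    # json.dumps() on the resulting arguments does not fail.
    args_dict = {}
    for key, val in args.items():
        if isinstance(val, proto.marshal.collections.repeated.RepeatedComposite):
            args_dict[key] = [v for v in val]
        else:
            args_dict[key] = val
    return args_dict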
|
@ -560,9 +573,9 @@ def completion(
|
||||||
optional_params["tools"] = tools
|
optional_params["tools"] = tools
|
||||||
elif mode == "chat":
|
elif mode == "chat":
|
||||||
chat = llm_model.start_chat()
|
chat = llm_model.start_chat()
|
||||||
request_str += f"chat = llm_model.start_chat()\n"
|
request_str += "chat = llm_model.start_chat()\n"
|
||||||
|
|
||||||
if "stream" in optional_params and optional_params["stream"] == True:
|
if "stream" in optional_params and optional_params["stream"] is True:
|
||||||
# NOTE: VertexAI does not accept stream=True as a param and raises an error,
|
# NOTE: VertexAI does not accept stream=True as a param and raises an error,
|
||||||
# we handle this by removing 'stream' from optional params and sending the request
|
# we handle this by removing 'stream' from optional params and sending the request
|
||||||
# after we get the response we add optional_params["stream"] = True, since main.py needs to know it's a streaming response to then transform it for the OpenAI format
|
# after we get the response we add optional_params["stream"] = True, since main.py needs to know it's a streaming response to then transform it for the OpenAI format
|
||||||
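The NOTE above describes the workaround for providers whose SDK call does not take a `stream` kwarg. A minimal sketch of that pattern, assuming a `send` callable that returns the full completion text (the names here are placeholders, not the actual LiteLLM helpers):

def call_without_stream_param(send, prompt, optional_params):
    # Pop 'stream' before calling the SDK (it would raise on an unknown kwarg),
    # then restore the flag so downstream callers still know to treat the
    # result as a streaming response.
    wants_stream = optional_params.pop("stream", False)
    completion_text = send(prompt, **optional_params)
    if wants_stream:
        optional_params["stream"] = True
        # fake streaming: yield the finished text in small chunks
        return (completion_text[i : i + 16] for i in range(0, len(completion_text), 16))
    return completion_text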
|
@ -597,7 +610,7 @@ def completion(
|
||||||
)
|
)
|
||||||
completion_response = chat.send_message(prompt, **optional_params).text
|
completion_response = chat.send_message(prompt, **optional_params).text
|
||||||
elif mode == "text":
|
elif mode == "text":
|
||||||
if "stream" in optional_params and optional_params["stream"] == True:
|
if "stream" in optional_params and optional_params["stream"] is True:
|
||||||
optional_params.pop(
|
optional_params.pop(
|
||||||
"stream", None
|
"stream", None
|
||||||
) # See note above on handling streaming for vertex ai
|
) # See note above on handling streaming for vertex ai
|
||||||
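Most of the mechanical changes in this file replace `== True` / `== False` comparisons with `is True` / `is False` (ruff E712) and bare `except:` with `except Exception:` (E722). The two comparison styles are not interchangeable for truthy non-bool values, which is worth keeping in mind while reviewing these hunks; a small illustrative snippet:

stream = 1  # a truthy, non-bool value

print(stream == True)   # True  - equality coerces across types
print(stream is True)   # False - identity only matches the bool singleton

# Assuming these flags are genuine bools (which is the intent in this code),
# swapping `== True` for `is True` keeps behaviour while satisfying E712.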
|
@ -632,6 +645,12 @@ def completion(
|
||||||
"""
|
"""
|
||||||
Vertex AI Model Garden
|
Vertex AI Model Garden
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
if vertex_project is None or vertex_location is None:
|
||||||
|
raise ValueError(
|
||||||
|
"Vertex project and location are required for custom endpoint"
|
||||||
|
)
|
||||||
|
|
||||||
## LOGGING
|
## LOGGING
|
||||||
logging_obj.pre_call(
|
logging_obj.pre_call(
|
||||||
input=prompt,
|
input=prompt,
|
||||||
|
@ -661,13 +680,17 @@ def completion(
|
||||||
and "\nOutput:\n" in completion_response
|
and "\nOutput:\n" in completion_response
|
||||||
):
|
):
|
||||||
completion_response = completion_response.split("\nOutput:\n", 1)[1]
|
completion_response = completion_response.split("\nOutput:\n", 1)[1]
|
||||||
if "stream" in optional_params and optional_params["stream"] == True:
|
if "stream" in optional_params and optional_params["stream"] is True:
|
||||||
response = TextStreamer(completion_response)
|
response = TextStreamer(completion_response)
|
||||||
return response
|
return response
|
||||||
elif mode == "private":
|
elif mode == "private":
|
||||||
"""
|
"""
|
||||||
Vertex AI Model Garden deployed on private endpoint
|
Vertex AI Model Garden deployed on private endpoint
|
||||||
"""
|
"""
|
||||||
|
if instances is None:
|
||||||
|
raise ValueError("instances are required for private endpoint")
|
||||||
|
if llm_model is None:
|
||||||
|
raise ValueError("Unable to pick client for private endpoint")
|
||||||
## LOGGING
|
## LOGGING
|
||||||
logging_obj.pre_call(
|
logging_obj.pre_call(
|
||||||
input=prompt,
|
input=prompt,
|
||||||
|
@ -686,7 +709,7 @@ def completion(
|
||||||
and "\nOutput:\n" in completion_response
|
and "\nOutput:\n" in completion_response
|
||||||
):
|
):
|
||||||
completion_response = completion_response.split("\nOutput:\n", 1)[1]
|
completion_response = completion_response.split("\nOutput:\n", 1)[1]
|
||||||
if "stream" in optional_params and optional_params["stream"] == True:
|
if "stream" in optional_params and optional_params["stream"] is True:
|
||||||
response = TextStreamer(completion_response)
|
response = TextStreamer(completion_response)
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
@ -715,7 +738,7 @@ def completion(
|
||||||
else:
|
else:
|
||||||
# init prompt tokens
|
# init prompt tokens
|
||||||
# this block attempts to get usage from response_obj if it exists, if not it uses the litellm token counter
|
# this block attempts to get usage from response_obj if it exists, if not it uses the litellm token counter
|
||||||
prompt_tokens, completion_tokens, total_tokens = 0, 0, 0
|
prompt_tokens, completion_tokens, _ = 0, 0, 0
|
||||||
if response_obj is not None:
|
if response_obj is not None:
|
||||||
if hasattr(response_obj, "usage_metadata") and hasattr(
|
if hasattr(response_obj, "usage_metadata") and hasattr(
|
||||||
response_obj.usage_metadata, "prompt_token_count"
|
response_obj.usage_metadata, "prompt_token_count"
|
||||||
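The renaming of the unused `total_tokens` to `_` above sits inside the usage-fallback block the comment describes: prefer the provider's `usage_metadata`, otherwise count tokens locally. A condensed sketch of that logic; field and helper names follow the surrounding code (`candidates_token_count` is the assumed Gemini counterpart field), but treat this as an approximation rather than the exact implementation:

import litellm


def usage_from_response(response_obj, model, messages, completion_text):
    # Prefer token counts reported by the provider; fall back to LiteLLM's
    # local token counter when usage_metadata is missing.
    if response_obj is not None and hasattr(response_obj, "usage_metadata") and hasattr(
        response_obj.usage_metadata, "prompt_token_count"
    ):
        prompt_tokens = response_obj.usage_metadata.prompt_token_count
        completion_tokens = response_obj.usage_metadata.candidates_token_count
    else:
        prompt_tokens = litellm.token_counter(model=model, messages=messages)
        completion_tokens = litellm.token_counter(model=model, text=completion_text or "")
    return litellm.Usage(
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
        total_tokens=prompt_tokens + completion_tokens,
    )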
|
@ -771,11 +794,13 @@ async def async_completion(
|
||||||
try:
|
try:
|
||||||
import proto # type: ignore
|
import proto # type: ignore
|
||||||
|
|
||||||
|
response_obj = None
|
||||||
|
completion_response = None
|
||||||
if mode == "vision":
|
if mode == "vision":
|
||||||
print_verbose("\nMaking VertexAI Gemini Pro/Vision Call")
|
print_verbose("\nMaking VertexAI Gemini Pro/Vision Call")
|
||||||
print_verbose(f"\nProcessing input messages = {messages}")
|
print_verbose(f"\nProcessing input messages = {messages}")
|
||||||
tools = optional_params.pop("tools", None)
|
tools = optional_params.pop("tools", None)
|
||||||
stream = optional_params.pop("stream", False)
|
optional_params.pop("stream", False)
|
||||||
|
|
||||||
content = _gemini_convert_messages_with_history(messages=messages)
|
content = _gemini_convert_messages_with_history(messages=messages)
|
||||||
|
|
||||||
|
@ -817,7 +842,7 @@ async def async_completion(
|
||||||
# Check if it's a RepeatedComposite instance
|
# Check if it's a RepeatedComposite instance
|
||||||
for key, val in function_call.args.items():
|
for key, val in function_call.args.items():
|
||||||
if isinstance(
|
if isinstance(
|
||||||
val, proto.marshal.collections.repeated.RepeatedComposite
|
val, proto.marshal.collections.repeated.RepeatedComposite # type: ignore
|
||||||
):
|
):
|
||||||
# If so, convert to list
|
# If so, convert to list
|
||||||
args_dict[key] = [v for v in val]
|
args_dict[key] = [v for v in val]
|
||||||
|
@ -880,6 +905,11 @@ async def async_completion(
|
||||||
"""
|
"""
|
||||||
from google.cloud import aiplatform # type: ignore
|
from google.cloud import aiplatform # type: ignore
|
||||||
|
|
||||||
|
if vertex_project is None or vertex_location is None:
|
||||||
|
raise ValueError(
|
||||||
|
"Vertex project and location are required for custom endpoint"
|
||||||
|
)
|
||||||
|
|
||||||
## LOGGING
|
## LOGGING
|
||||||
logging_obj.pre_call(
|
logging_obj.pre_call(
|
||||||
input=prompt,
|
input=prompt,
|
||||||
|
@ -953,7 +983,7 @@ async def async_completion(
|
||||||
else:
|
else:
|
||||||
# init prompt tokens
|
# init prompt tokens
|
||||||
# this block attempts to get usage from response_obj if it exists, if not it uses the litellm token counter
|
# this block attempts to get usage from response_obj if it exists, if not it uses the litellm token counter
|
||||||
prompt_tokens, completion_tokens, total_tokens = 0, 0, 0
|
prompt_tokens, completion_tokens, _ = 0, 0, 0
|
||||||
if response_obj is not None and (
|
if response_obj is not None and (
|
||||||
hasattr(response_obj, "usage_metadata")
|
hasattr(response_obj, "usage_metadata")
|
||||||
and hasattr(response_obj.usage_metadata, "prompt_token_count")
|
and hasattr(response_obj.usage_metadata, "prompt_token_count")
|
||||||
|
@ -1001,6 +1031,7 @@ async def async_streaming(
|
||||||
"""
|
"""
|
||||||
Add support for async streaming calls for gemini-pro
|
Add support for async streaming calls for gemini-pro
|
||||||
"""
|
"""
|
||||||
|
response: Any = None
|
||||||
if mode == "vision":
|
if mode == "vision":
|
||||||
stream = optional_params.pop("stream")
|
stream = optional_params.pop("stream")
|
||||||
tools = optional_params.pop("tools", None)
|
tools = optional_params.pop("tools", None)
|
||||||
|
@ -1065,6 +1096,11 @@ async def async_streaming(
|
||||||
elif mode == "custom":
|
elif mode == "custom":
|
||||||
from google.cloud import aiplatform # type: ignore
|
from google.cloud import aiplatform # type: ignore
|
||||||
|
|
||||||
|
if vertex_project is None or vertex_location is None:
|
||||||
|
raise ValueError(
|
||||||
|
"Vertex project and location are required for custom endpoint"
|
||||||
|
)
|
||||||
|
|
||||||
stream = optional_params.pop("stream", None)
|
stream = optional_params.pop("stream", None)
|
||||||
|
|
||||||
## LOGGING
|
## LOGGING
|
||||||
|
@ -1102,6 +1138,8 @@ async def async_streaming(
|
||||||
response = TextStreamer(completion_response)
|
response = TextStreamer(completion_response)
|
||||||
|
|
||||||
elif mode == "private":
|
elif mode == "private":
|
||||||
|
if instances is None:
|
||||||
|
raise ValueError("Instances are required for private endpoint")
|
||||||
stream = optional_params.pop("stream", None)
|
stream = optional_params.pop("stream", None)
|
||||||
_ = instances[0].pop("stream", None)
|
_ = instances[0].pop("stream", None)
|
||||||
request_str += f"llm_model.predict_async(instances={instances})\n"
|
request_str += f"llm_model.predict_async(instances={instances})\n"
|
||||||
|
@ -1118,6 +1156,9 @@ async def async_streaming(
|
||||||
if stream:
|
if stream:
|
||||||
response = TextStreamer(completion_response)
|
response = TextStreamer(completion_response)
|
||||||
|
|
||||||
|
if response is None:
|
||||||
|
raise ValueError("Unable to generate response")
|
||||||
|
|
||||||
logging_obj.post_call(input=prompt, api_key=None, original_response=response)
|
logging_obj.post_call(input=prompt, api_key=None, original_response=response)
|
||||||
|
|
||||||
streamwrapper = CustomStreamWrapper(
|
streamwrapper = CustomStreamWrapper(
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import types
|
import types
|
||||||
from typing import Literal, Optional, Union
|
from typing import Any, Literal, Optional, Union, cast
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
@ -53,7 +53,7 @@ class VertexEmbedding(VertexBase):
|
||||||
gemini_api_key: Optional[str] = None,
|
gemini_api_key: Optional[str] = None,
|
||||||
extra_headers: Optional[dict] = None,
|
extra_headers: Optional[dict] = None,
|
||||||
):
|
):
|
||||||
if aembedding == True:
|
if aembedding is True:
|
||||||
return self.async_embedding(
|
return self.async_embedding(
|
||||||
model=model,
|
model=model,
|
||||||
input=input,
|
input=input,
|
||||||
|
|
|
@ -45,8 +45,8 @@ def completion(
|
||||||
print_verbose: Callable,
|
print_verbose: Callable,
|
||||||
encoding,
|
encoding,
|
||||||
logging_obj,
|
logging_obj,
|
||||||
|
optional_params: dict,
|
||||||
custom_prompt_dict={},
|
custom_prompt_dict={},
|
||||||
optional_params=None,
|
|
||||||
litellm_params=None,
|
litellm_params=None,
|
||||||
logger_fn=None,
|
logger_fn=None,
|
||||||
):
|
):
|
||||||
|
@ -83,7 +83,7 @@ def completion(
|
||||||
)
|
)
|
||||||
|
|
||||||
## COMPLETION CALL
|
## COMPLETION CALL
|
||||||
if "stream" in optional_params and optional_params["stream"] == True:
|
if "stream" in optional_params and optional_params["stream"] is True:
|
||||||
return iter(outputs)
|
return iter(outputs)
|
||||||
else:
|
else:
|
||||||
## LOGGING
|
## LOGGING
|
||||||
|
@ -144,10 +144,7 @@ def batch_completions(
|
||||||
llm, SamplingParams = validate_environment(model=model)
|
llm, SamplingParams = validate_environment(model=model)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_str = str(e)
|
error_str = str(e)
|
||||||
if "data parallel group is already initialized" in error_str:
|
raise VLLMError(status_code=0, message=error_str)
|
||||||
pass
|
|
||||||
else:
|
|
||||||
raise VLLMError(status_code=0, message=error_str)
|
|
||||||
sampling_params = SamplingParams(**optional_params)
|
sampling_params = SamplingParams(**optional_params)
|
||||||
prompts = []
|
prompts = []
|
||||||
if model in custom_prompt_dict:
|
if model in custom_prompt_dict:
|
||||||
|
|
134
litellm/main.py
134
litellm/main.py
|
@ -106,6 +106,7 @@ from .llms.prompt_templates.factory import (
|
||||||
custom_prompt,
|
custom_prompt,
|
||||||
function_call_prompt,
|
function_call_prompt,
|
||||||
map_system_message_pt,
|
map_system_message_pt,
|
||||||
|
ollama_pt,
|
||||||
prompt_factory,
|
prompt_factory,
|
||||||
stringify_json_tool_call_content,
|
stringify_json_tool_call_content,
|
||||||
)
|
)
|
||||||
|
@ -150,7 +151,6 @@ from .types.utils import (
|
||||||
encoding = tiktoken.get_encoding("cl100k_base")
|
encoding = tiktoken.get_encoding("cl100k_base")
|
||||||
from litellm.utils import (
|
from litellm.utils import (
|
||||||
Choices,
|
Choices,
|
||||||
CustomStreamWrapper,
|
|
||||||
EmbeddingResponse,
|
EmbeddingResponse,
|
||||||
ImageResponse,
|
ImageResponse,
|
||||||
Message,
|
Message,
|
||||||
|
@ -159,8 +159,6 @@ from litellm.utils import (
|
||||||
TextCompletionResponse,
|
TextCompletionResponse,
|
||||||
TextCompletionStreamWrapper,
|
TextCompletionStreamWrapper,
|
||||||
TranscriptionResponse,
|
TranscriptionResponse,
|
||||||
get_secret,
|
|
||||||
read_config_args,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
####### ENVIRONMENT VARIABLES ###################
|
####### ENVIRONMENT VARIABLES ###################
|
||||||
|
@ -214,7 +212,7 @@ class LiteLLM:
|
||||||
class Chat:
|
class Chat:
|
||||||
def __init__(self, params, router_obj: Optional[Any]):
|
def __init__(self, params, router_obj: Optional[Any]):
|
||||||
self.params = params
|
self.params = params
|
||||||
if self.params.get("acompletion", False) == True:
|
if self.params.get("acompletion", False) is True:
|
||||||
self.params.pop("acompletion")
|
self.params.pop("acompletion")
|
||||||
self.completions: Union[AsyncCompletions, Completions] = AsyncCompletions(
|
self.completions: Union[AsyncCompletions, Completions] = AsyncCompletions(
|
||||||
self.params, router_obj=router_obj
|
self.params, router_obj=router_obj
|
||||||
|
@ -837,10 +835,10 @@ def completion(
|
||||||
model_response = ModelResponse()
|
model_response = ModelResponse()
|
||||||
setattr(model_response, "usage", litellm.Usage())
|
setattr(model_response, "usage", litellm.Usage())
|
||||||
if (
|
if (
|
||||||
kwargs.get("azure", False) == True
|
kwargs.get("azure", False) is True
|
||||||
): # don't remove flag check, to remain backwards compatible for repos like Codium
|
): # don't remove flag check, to remain backwards compatible for repos like Codium
|
||||||
custom_llm_provider = "azure"
|
custom_llm_provider = "azure"
|
||||||
if deployment_id != None: # azure llms
|
if deployment_id is not None: # azure llms
|
||||||
model = deployment_id
|
model = deployment_id
|
||||||
custom_llm_provider = "azure"
|
custom_llm_provider = "azure"
|
||||||
model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(
|
model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(
|
||||||
|
@ -1156,7 +1154,7 @@ def completion(
|
||||||
client=client, # pass AsyncAzureOpenAI, AzureOpenAI client
|
client=client, # pass AsyncAzureOpenAI, AzureOpenAI client
|
||||||
)
|
)
|
||||||
|
|
||||||
if optional_params.get("stream", False) or acompletion == True:
|
if optional_params.get("stream", False) or acompletion is True:
|
||||||
## LOGGING
|
## LOGGING
|
||||||
logging.post_call(
|
logging.post_call(
|
||||||
input=messages,
|
input=messages,
|
||||||
|
@ -1278,7 +1276,7 @@ def completion(
|
||||||
if (
|
if (
|
||||||
len(messages) > 0
|
len(messages) > 0
|
||||||
and "content" in messages[0]
|
and "content" in messages[0]
|
||||||
and type(messages[0]["content"]) == list
|
and isinstance(messages[0]["content"], list)
|
||||||
):
|
):
|
||||||
# text-davinci-003 can accept a string or array, if it's an array, assume the array is set in messages[0]['content']
|
# text-davinci-003 can accept a string or array, if it's an array, assume the array is set in messages[0]['content']
|
||||||
# https://platform.openai.com/docs/api-reference/completions/create
|
# https://platform.openai.com/docs/api-reference/completions/create
|
||||||
|
@ -1304,16 +1302,16 @@ def completion(
|
||||||
)
|
)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
optional_params.get("stream", False) == False
|
optional_params.get("stream", False) is False
|
||||||
and acompletion == False
|
and acompletion is False
|
||||||
and text_completion == False
|
and text_completion is False
|
||||||
):
|
):
|
||||||
# convert to chat completion response
|
# convert to chat completion response
|
||||||
_response = litellm.OpenAITextCompletionConfig().convert_to_chat_model_response_object(
|
_response = litellm.OpenAITextCompletionConfig().convert_to_chat_model_response_object(
|
||||||
response_object=_response, model_response_object=model_response
|
response_object=_response, model_response_object=model_response
|
||||||
)
|
)
|
||||||
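For context on the branch above: when the call is synchronous and non-streaming, the text-completion payload is reshaped into the chat-completion format. Roughly, that conversion amounts to the following hand-rolled sketch, not the actual `convert_to_chat_model_response_object` helper:

def text_to_chat_shape(text_response: dict, model: str) -> dict:
    # Map an OpenAI-style /completions payload onto the /chat/completions shape
    # so downstream consumers only deal with one response format.
    return {
        "id": text_response.get("id"),
        "object": "chat.completion",
        "created": text_response.get("created"),
        "model": model,
        "choices": [
            {
                "index": choice.get("index", 0),
                "message": {"role": "assistant", "content": choice.get("text", "")},
                "finish_reason": choice.get("finish_reason"),
            }
            for choice in text_response.get("choices", [])
        ],
        "usage": text_response.get("usage", {}),
    }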
|
|
||||||
if optional_params.get("stream", False) or acompletion == True:
|
if optional_params.get("stream", False) or acompletion is True:
|
||||||
## LOGGING
|
## LOGGING
|
||||||
logging.post_call(
|
logging.post_call(
|
||||||
input=messages,
|
input=messages,
|
||||||
|
@ -1519,7 +1517,7 @@ def completion(
|
||||||
acompletion=acompletion,
|
acompletion=acompletion,
|
||||||
)
|
)
|
||||||
|
|
||||||
if optional_params.get("stream", False) == True:
|
if optional_params.get("stream", False) is True:
|
||||||
## LOGGING
|
## LOGGING
|
||||||
logging.post_call(
|
logging.post_call(
|
||||||
input=messages,
|
input=messages,
|
||||||
|
@ -1566,7 +1564,7 @@ def completion(
|
||||||
custom_prompt_dict=custom_prompt_dict,
|
custom_prompt_dict=custom_prompt_dict,
|
||||||
)
|
)
|
||||||
|
|
||||||
if "stream" in optional_params and optional_params["stream"] == True:
|
if "stream" in optional_params and optional_params["stream"] is True:
|
||||||
# don't try to access stream object,
|
# don't try to access stream object,
|
||||||
## LOGGING
|
## LOGGING
|
||||||
logging.post_call(
|
logging.post_call(
|
||||||
|
@ -1575,7 +1573,7 @@ def completion(
|
||||||
original_response=model_response,
|
original_response=model_response,
|
||||||
)
|
)
|
||||||
|
|
||||||
if optional_params.get("stream", False) or acompletion == True:
|
if optional_params.get("stream", False) or acompletion is True:
|
||||||
## LOGGING
|
## LOGGING
|
||||||
logging.post_call(
|
logging.post_call(
|
||||||
input=messages,
|
input=messages,
|
||||||
|
@ -1654,7 +1652,7 @@ def completion(
|
||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
client=client,
|
client=client,
|
||||||
)
|
)
|
||||||
if optional_params.get("stream", False) or acompletion == True:
|
if optional_params.get("stream", False) or acompletion is True:
|
||||||
## LOGGING
|
## LOGGING
|
||||||
logging.post_call(
|
logging.post_call(
|
||||||
input=messages,
|
input=messages,
|
||||||
|
@ -1691,7 +1689,7 @@ def completion(
|
||||||
logging_obj=logging,
|
logging_obj=logging,
|
||||||
)
|
)
|
||||||
|
|
||||||
if "stream" in optional_params and optional_params["stream"] == True:
|
if "stream" in optional_params and optional_params["stream"] is True:
|
||||||
# don't try to access stream object,
|
# don't try to access stream object,
|
||||||
response = CustomStreamWrapper(
|
response = CustomStreamWrapper(
|
||||||
response,
|
response,
|
||||||
|
@ -1700,7 +1698,7 @@ def completion(
|
||||||
logging_obj=logging,
|
logging_obj=logging,
|
||||||
)
|
)
|
||||||
|
|
||||||
if optional_params.get("stream", False) or acompletion == True:
|
if optional_params.get("stream", False) or acompletion is True:
|
||||||
## LOGGING
|
## LOGGING
|
||||||
logging.post_call(
|
logging.post_call(
|
||||||
input=messages,
|
input=messages,
|
||||||
|
@ -1740,7 +1738,7 @@ def completion(
|
||||||
logging_obj=logging, # model call logging done inside the class as we may need to modify I/O to fit aleph alpha's requirements
|
logging_obj=logging, # model call logging done inside the class as we may need to modify I/O to fit aleph alpha's requirements
|
||||||
)
|
)
|
||||||
|
|
||||||
if "stream" in optional_params and optional_params["stream"] == True:
|
if "stream" in optional_params and optional_params["stream"] is True:
|
||||||
# don't try to access stream object,
|
# don't try to access stream object,
|
||||||
response = CustomStreamWrapper(
|
response = CustomStreamWrapper(
|
||||||
model_response,
|
model_response,
|
||||||
|
@ -1788,7 +1786,7 @@ def completion(
|
||||||
logging_obj=logging, # model call logging done inside the class as we may need to modify I/O to fit aleph alpha's requirements
|
logging_obj=logging, # model call logging done inside the class as we may need to modify I/O to fit aleph alpha's requirements
|
||||||
)
|
)
|
||||||
|
|
||||||
if "stream" in optional_params and optional_params["stream"] == True:
|
if "stream" in optional_params and optional_params["stream"] is True:
|
||||||
# don't try to access stream object,
|
# don't try to access stream object,
|
||||||
response = CustomStreamWrapper(
|
response = CustomStreamWrapper(
|
||||||
model_response,
|
model_response,
|
||||||
|
@ -1836,7 +1834,7 @@ def completion(
|
||||||
logging_obj=logging, # model call logging done inside the class as we may need to modify I/O to fit aleph alpha's requirements
|
logging_obj=logging, # model call logging done inside the class as we may need to modify I/O to fit aleph alpha's requirements
|
||||||
)
|
)
|
||||||
|
|
||||||
if "stream" in optional_params and optional_params["stream"] == True:
|
if "stream" in optional_params and optional_params["stream"] is True:
|
||||||
# don't try to access stream object,
|
# don't try to access stream object,
|
||||||
response = CustomStreamWrapper(
|
response = CustomStreamWrapper(
|
||||||
model_response,
|
model_response,
|
||||||
|
@ -1875,7 +1873,7 @@ def completion(
|
||||||
logging_obj=logging,
|
logging_obj=logging,
|
||||||
)
|
)
|
||||||
|
|
||||||
if "stream" in optional_params and optional_params["stream"] == True:
|
if "stream" in optional_params and optional_params["stream"] is True:
|
||||||
# don't try to access stream object,
|
# don't try to access stream object,
|
||||||
response = CustomStreamWrapper(
|
response = CustomStreamWrapper(
|
||||||
model_response,
|
model_response,
|
||||||
|
@ -1916,7 +1914,7 @@ def completion(
|
||||||
)
|
)
|
||||||
if (
|
if (
|
||||||
"stream" in optional_params
|
"stream" in optional_params
|
||||||
and optional_params["stream"] == True
|
and optional_params["stream"] is True
|
||||||
and acompletion is False
|
and acompletion is False
|
||||||
):
|
):
|
||||||
# don't try to access stream object,
|
# don't try to access stream object,
|
||||||
|
@ -1943,7 +1941,7 @@ def completion(
|
||||||
encoding=encoding,
|
encoding=encoding,
|
||||||
logging_obj=logging,
|
logging_obj=logging,
|
||||||
)
|
)
|
||||||
if "stream" in optional_params and optional_params["stream"] == True:
|
if "stream" in optional_params and optional_params["stream"] is True:
|
||||||
# don't try to access stream object,
|
# don't try to access stream object,
|
||||||
response = CustomStreamWrapper(
|
response = CustomStreamWrapper(
|
||||||
model_response,
|
model_response,
|
||||||
|
@ -2095,7 +2093,7 @@ def completion(
|
||||||
logging_obj=logging,
|
logging_obj=logging,
|
||||||
)
|
)
|
||||||
# fake palm streaming
|
# fake palm streaming
|
||||||
if "stream" in optional_params and optional_params["stream"] == True:
|
if "stream" in optional_params and optional_params["stream"] is True:
|
||||||
# fake streaming for palm
|
# fake streaming for palm
|
||||||
resp_string = model_response["choices"][0]["message"]["content"]
|
resp_string = model_response["choices"][0]["message"]["content"]
|
||||||
response = CustomStreamWrapper(
|
response = CustomStreamWrapper(
|
||||||
|
@ -2390,7 +2388,7 @@ def completion(
|
||||||
logging_obj=logging,
|
logging_obj=logging,
|
||||||
)
|
)
|
||||||
|
|
||||||
if "stream" in optional_params and optional_params["stream"] == True:
|
if "stream" in optional_params and optional_params["stream"] is True:
|
||||||
# don't try to access stream object,
|
# don't try to access stream object,
|
||||||
response = CustomStreamWrapper(
|
response = CustomStreamWrapper(
|
||||||
model_response,
|
model_response,
|
||||||
|
@ -2527,7 +2525,7 @@ def completion(
|
||||||
)
|
)
|
||||||
if (
|
if (
|
||||||
"stream" in optional_params
|
"stream" in optional_params
|
||||||
and optional_params["stream"] == True
|
and optional_params["stream"] is True
|
||||||
and not isinstance(response, CustomStreamWrapper)
|
and not isinstance(response, CustomStreamWrapper)
|
||||||
):
|
):
|
||||||
# don't try to access stream object,
|
# don't try to access stream object,
|
||||||
|
@ -2563,7 +2561,7 @@ def completion(
|
||||||
)
|
)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
"stream" in optional_params and optional_params["stream"] == True
|
"stream" in optional_params and optional_params["stream"] is True
|
||||||
): ## [BETA]
|
): ## [BETA]
|
||||||
# don't try to access stream object,
|
# don't try to access stream object,
|
||||||
response = CustomStreamWrapper(
|
response = CustomStreamWrapper(
|
||||||
|
@ -2587,38 +2585,38 @@ def completion(
|
||||||
if model in custom_prompt_dict:
|
if model in custom_prompt_dict:
|
||||||
# check if the model has a registered custom prompt
|
# check if the model has a registered custom prompt
|
||||||
model_prompt_details = custom_prompt_dict[model]
|
model_prompt_details = custom_prompt_dict[model]
|
||||||
prompt = custom_prompt(
|
ollama_prompt = custom_prompt(
|
||||||
role_dict=model_prompt_details["roles"],
|
role_dict=model_prompt_details["roles"],
|
||||||
initial_prompt_value=model_prompt_details["initial_prompt_value"],
|
initial_prompt_value=model_prompt_details["initial_prompt_value"],
|
||||||
final_prompt_value=model_prompt_details["final_prompt_value"],
|
final_prompt_value=model_prompt_details["final_prompt_value"],
|
||||||
messages=messages,
|
messages=messages,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
prompt = prompt_factory(
|
modified_prompt = ollama_pt(model=model, messages=messages)
|
||||||
model=model,
|
if isinstance(modified_prompt, dict):
|
||||||
messages=messages,
|
|
||||||
custom_llm_provider=custom_llm_provider,
|
|
||||||
)
|
|
||||||
if isinstance(prompt, dict):
|
|
||||||
# for multimodal models - ollama/llava prompt_factory returns a dict {
|
# for multimodal models - ollama/llava prompt_factory returns a dict {
|
||||||
# "prompt": prompt,
|
# "prompt": prompt,
|
||||||
# "images": images
|
# "images": images
|
||||||
# }
|
# }
|
||||||
prompt, images = prompt["prompt"], prompt["images"]
|
ollama_prompt, images = (
|
||||||
|
modified_prompt["prompt"],
|
||||||
|
modified_prompt["images"],
|
||||||
|
)
|
||||||
optional_params["images"] = images
|
optional_params["images"] = images
|
||||||
|
else:
|
||||||
|
ollama_prompt = modified_prompt
|
||||||
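The rewritten block above switches the Ollama path from the generic `prompt_factory` to `ollama_pt` and renames the variable so pyright can track its type. The shape it has to handle is either a plain string or, for multimodal llava-style models, a dict carrying the prompt plus base64 images. A small sketch of that dispatch, assuming `ollama_pt`'s return shape as used here; the helper name below is illustrative:

from typing import Any, List, Tuple


def split_ollama_prompt(modified_prompt: Any) -> Tuple[str, List[str]]:
    # ollama_pt returns a str for text-only models and a dict with
    # {"prompt": ..., "images": [...]} for multimodal models such as llava.
    if isinstance(modified_prompt, dict):
        return modified_prompt["prompt"], modified_prompt["images"]
    return modified_prompt, []


# usage:
# ollama_prompt, images = split_ollama_prompt(ollama_pt(model=model, messages=messages))
# if images:
#     optional_params["images"] = images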
## LOGGING
|
## LOGGING
|
||||||
generator = ollama.get_ollama_response(
|
generator = ollama.get_ollama_response(
|
||||||
api_base=api_base,
|
api_base=api_base,
|
||||||
model=model,
|
model=model,
|
||||||
prompt=prompt,
|
prompt=ollama_prompt,
|
||||||
optional_params=optional_params,
|
optional_params=optional_params,
|
||||||
logging_obj=logging,
|
logging_obj=logging,
|
||||||
acompletion=acompletion,
|
acompletion=acompletion,
|
||||||
model_response=model_response,
|
model_response=model_response,
|
||||||
encoding=encoding,
|
encoding=encoding,
|
||||||
)
|
)
|
||||||
if acompletion is True or optional_params.get("stream", False) == True:
|
if acompletion is True or optional_params.get("stream", False) is True:
|
||||||
return generator
|
return generator
|
||||||
|
|
||||||
response = generator
|
response = generator
|
||||||
|
@ -2701,7 +2699,7 @@ def completion(
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
logging_obj=logging,
|
logging_obj=logging,
|
||||||
)
|
)
|
||||||
if "stream" in optional_params and optional_params["stream"] == True:
|
if "stream" in optional_params and optional_params["stream"] is True:
|
||||||
# don't try to access stream object,
|
# don't try to access stream object,
|
||||||
response = CustomStreamWrapper(
|
response = CustomStreamWrapper(
|
||||||
response,
|
response,
|
||||||
|
@ -2710,7 +2708,7 @@ def completion(
|
||||||
logging_obj=logging,
|
logging_obj=logging,
|
||||||
)
|
)
|
||||||
|
|
||||||
if optional_params.get("stream", False) or acompletion == True:
|
if optional_params.get("stream", False) or acompletion is True:
|
||||||
## LOGGING
|
## LOGGING
|
||||||
logging.post_call(
|
logging.post_call(
|
||||||
input=messages,
|
input=messages,
|
||||||
|
@ -2743,7 +2741,7 @@ def completion(
|
||||||
logging_obj=logging,
|
logging_obj=logging,
|
||||||
)
|
)
|
||||||
if inspect.isgenerator(model_response) or (
|
if inspect.isgenerator(model_response) or (
|
||||||
"stream" in optional_params and optional_params["stream"] == True
|
"stream" in optional_params and optional_params["stream"] is True
|
||||||
):
|
):
|
||||||
# don't try to access stream object,
|
# don't try to access stream object,
|
||||||
response = CustomStreamWrapper(
|
response = CustomStreamWrapper(
|
||||||
|
@ -2771,7 +2769,7 @@ def completion(
|
||||||
encoding=encoding,
|
encoding=encoding,
|
||||||
logging_obj=logging,
|
logging_obj=logging,
|
||||||
)
|
)
|
||||||
if stream == True: ## [BETA]
|
if stream is True: ## [BETA]
|
||||||
# Fake streaming for petals
|
# Fake streaming for petals
|
||||||
resp_string = model_response["choices"][0]["message"]["content"]
|
resp_string = model_response["choices"][0]["message"]["content"]
|
||||||
response = CustomStreamWrapper(
|
response = CustomStreamWrapper(
|
||||||
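Several providers in this file (Petals, the Vertex model-garden paths, palm) only return a finished string, so streaming is faked by chunking the completed text and wrapping it in `CustomStreamWrapper`. A stripped-down sketch of that idea, using a plain generator instead of the real wrapper class:

from typing import Iterator


def fake_stream(completed_text: str, chunk_size: int = 20) -> Iterator[str]:
    # Emit the already-completed response in small slices so callers that
    # expect an iterator of deltas can consume it unchanged.
    for i in range(0, len(completed_text), chunk_size):
        yield completed_text[i : i + chunk_size]


# usage:
# for delta in fake_stream(resp_string):
#     print(delta, end="", flush=True)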
|
@ -2786,7 +2784,7 @@ def completion(
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
url = litellm.api_base or api_base or ""
|
url = litellm.api_base or api_base or ""
|
||||||
if url == None or url == "":
|
if url is None or url == "":
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"api_base not set. Set api_base or litellm.api_base for custom endpoints"
|
"api_base not set. Set api_base or litellm.api_base for custom endpoints"
|
||||||
)
|
)
|
||||||
|
@ -3145,10 +3143,10 @@ def batch_completion_models(*args, **kwargs):
|
||||||
try:
|
try:
|
||||||
result = future.result()
|
result = future.result()
|
||||||
return result
|
return result
|
||||||
except Exception as e:
|
except Exception:
|
||||||
# if model 1 fails, continue with response from model 2, model3
|
# if model 1 fails, continue with response from model 2, model3
|
||||||
print_verbose(
|
print_verbose(
|
||||||
f"\n\ngot an exception, ignoring, removing from futures"
|
"\n\ngot an exception, ignoring, removing from futures"
|
||||||
)
|
)
|
||||||
print_verbose(futures)
|
print_verbose(futures)
|
||||||
new_futures = {}
|
new_futures = {}
|
||||||
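`batch_completion_models` fans the same request out to several models and returns the first result that comes back, swallowing individual failures (hence the `except Exception` tightening above). The overall control flow looks roughly like this sketch, with simplified synchronous calls assumed:

import concurrent.futures


def first_successful_completion(models, make_call):
    # Submit one call per model and return the first future that completes
    # without raising; failed futures are simply ignored.
    with concurrent.futures.ThreadPoolExecutor() as executor:
        futures = {executor.submit(make_call, m): m for m in models}
        for future in concurrent.futures.as_completed(futures):
            try:
                return future.result()
            except Exception:
                continue  # that model failed; keep waiting on the rest
    raise RuntimeError("all models failed")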
|
@ -3189,9 +3187,6 @@ def batch_completion_models_all_responses(*args, **kwargs):
|
||||||
import concurrent.futures
|
import concurrent.futures
|
||||||
|
|
||||||
# ANSI escape codes for colored output
|
# ANSI escape codes for colored output
|
||||||
GREEN = "\033[92m"
|
|
||||||
RED = "\033[91m"
|
|
||||||
RESET = "\033[0m"
|
|
||||||
|
|
||||||
if "model" in kwargs:
|
if "model" in kwargs:
|
||||||
kwargs.pop("model")
|
kwargs.pop("model")
|
||||||
|
@ -3520,7 +3515,7 @@ def embedding(
|
||||||
|
|
||||||
if api_base is None:
|
if api_base is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"No API Base provided for Azure OpenAI LLM provider. Set 'AZURE_API_BASE' in .env"
|
"No API Base provided for Azure OpenAI LLM provider. Set 'AZURE_API_BASE' in .env"
|
||||||
)
|
)
|
||||||
|
|
||||||
## EMBEDDING CALL
|
## EMBEDDING CALL
|
||||||
|
@ -4106,7 +4101,6 @@ def text_completion(
|
||||||
*args,
|
*args,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
global print_verbose
|
|
||||||
import copy
|
import copy
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
@ -4136,7 +4130,7 @@ def text_completion(
|
||||||
Your example of how to use this function goes here.
|
Your example of how to use this function goes here.
|
||||||
"""
|
"""
|
||||||
if "engine" in kwargs:
|
if "engine" in kwargs:
|
||||||
if model == None:
|
if model is None:
|
||||||
# only use engine when model not passed
|
# only use engine when model not passed
|
||||||
model = kwargs["engine"]
|
model = kwargs["engine"]
|
||||||
kwargs.pop("engine")
|
kwargs.pop("engine")
|
||||||
|
@ -4189,18 +4183,18 @@ def text_completion(
|
||||||
|
|
||||||
if custom_llm_provider == "huggingface":
|
if custom_llm_provider == "huggingface":
|
||||||
# if echo == True, for TGI llms we need to set top_n_tokens to 3
|
# if echo == True, for TGI llms we need to set top_n_tokens to 3
|
||||||
if echo == True:
|
if echo is True:
|
||||||
# for tgi llms
|
# for tgi llms
|
||||||
if "top_n_tokens" not in kwargs:
|
if "top_n_tokens" not in kwargs:
|
||||||
kwargs["top_n_tokens"] = 3
|
kwargs["top_n_tokens"] = 3
|
||||||
|
|
||||||
# processing prompt - users can pass raw tokens to OpenAI Completion()
|
# processing prompt - users can pass raw tokens to OpenAI Completion()
|
||||||
if type(prompt) == list:
|
if isinstance(prompt, list):
|
||||||
import concurrent.futures
|
import concurrent.futures
|
||||||
|
|
||||||
tokenizer = tiktoken.encoding_for_model("text-davinci-003")
|
tokenizer = tiktoken.encoding_for_model("text-davinci-003")
|
||||||
## if it's a 2d list - each element in the list is a text_completion() request
|
## if it's a 2d list - each element in the list is a text_completion() request
|
||||||
if len(prompt) > 0 and type(prompt[0]) == list:
|
if len(prompt) > 0 and isinstance(prompt[0], list):
|
||||||
responses = [None for x in prompt] # init responses
|
responses = [None for x in prompt] # init responses
|
||||||
|
|
||||||
def process_prompt(i, individual_prompt):
|
def process_prompt(i, individual_prompt):
|
||||||
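`text_completion` also accepts raw token IDs, including a 2-D list where each inner list is its own request, which is why the hunk above swaps `type(...) ==` checks for `isinstance` (ruff E721) around the prompt handling. A sketch of how such a prompt can be normalised back to strings with tiktoken; the decoder choice mirrors the surrounding code, and the helper name is illustrative:

import tiktoken


def decode_token_prompts(prompt):
    # A 2-D list of ints means "one text_completion request per inner list";
    # decode each back to text so it can be processed like a normal prompt.
    tokenizer = tiktoken.encoding_for_model("text-davinci-003")
    if isinstance(prompt, list) and len(prompt) > 0 and isinstance(prompt[0], list):
        return [tokenizer.decode(token_ids) for token_ids in prompt]
    if isinstance(prompt, list) and len(prompt) > 0 and isinstance(prompt[0], int):
        return [tokenizer.decode(prompt)]
    return [prompt]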
|
@ -4299,7 +4293,7 @@ def text_completion(
|
||||||
raw_response = response._hidden_params.get("original_response", None)
|
raw_response = response._hidden_params.get("original_response", None)
|
||||||
transformed_logprobs = litellm.utils.transform_logprobs(raw_response)
|
transformed_logprobs = litellm.utils.transform_logprobs(raw_response)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print_verbose(f"LiteLLM non blocking exception: {e}")
|
verbose_logger.exception(f"LiteLLM non blocking exception: {e}")
|
||||||
|
|
||||||
if isinstance(response, TextCompletionResponse):
|
if isinstance(response, TextCompletionResponse):
|
||||||
return response
|
return response
|
||||||
|
@ -4813,12 +4807,12 @@ def transcription(
|
||||||
Allows router to load balance between them
|
Allows router to load balance between them
|
||||||
"""
|
"""
|
||||||
atranscription = kwargs.get("atranscription", False)
|
atranscription = kwargs.get("atranscription", False)
|
||||||
litellm_call_id = kwargs.get("litellm_call_id", None)
|
kwargs.get("litellm_call_id", None)
|
||||||
logger_fn = kwargs.get("logger_fn", None)
|
kwargs.get("logger_fn", None)
|
||||||
proxy_server_request = kwargs.get("proxy_server_request", None)
|
kwargs.get("proxy_server_request", None)
|
||||||
model_info = kwargs.get("model_info", None)
|
kwargs.get("model_info", None)
|
||||||
metadata = kwargs.get("metadata", {})
|
kwargs.get("metadata", {})
|
||||||
tags = kwargs.pop("tags", [])
|
kwargs.pop("tags", [])
|
||||||
|
|
||||||
drop_params = kwargs.get("drop_params", None)
|
drop_params = kwargs.get("drop_params", None)
|
||||||
client: Optional[
|
client: Optional[
|
||||||
|
@ -4996,7 +4990,7 @@ def speech(
|
||||||
model_info = kwargs.get("model_info", None)
|
model_info = kwargs.get("model_info", None)
|
||||||
metadata = kwargs.get("metadata", {})
|
metadata = kwargs.get("metadata", {})
|
||||||
model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base) # type: ignore
|
model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base) # type: ignore
|
||||||
tags = kwargs.pop("tags", [])
|
kwargs.pop("tags", [])
|
||||||
|
|
||||||
optional_params = {}
|
optional_params = {}
|
||||||
if response_format is not None:
|
if response_format is not None:
|
||||||
|
@ -5345,12 +5339,12 @@ def print_verbose(print_statement):
|
||||||
verbose_logger.debug(print_statement)
|
verbose_logger.debug(print_statement)
|
||||||
if litellm.set_verbose:
|
if litellm.set_verbose:
|
||||||
print(print_statement) # noqa
|
print(print_statement) # noqa
|
||||||
except:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def config_completion(**kwargs):
|
def config_completion(**kwargs):
|
||||||
if litellm.config_path != None:
|
if litellm.config_path is not None:
|
||||||
config_args = read_config_args(litellm.config_path)
|
config_args = read_config_args(litellm.config_path)
|
||||||
# overwrite any args passed in with config args
|
# overwrite any args passed in with config args
|
||||||
return completion(**kwargs, **config_args)
|
return completion(**kwargs, **config_args)
|
||||||
|
@ -5408,16 +5402,18 @@ def stream_chunk_builder_text_completion(chunks: list, messages: Optional[List]
|
||||||
response["choices"][0]["text"] = combined_content
|
response["choices"][0]["text"] = combined_content
|
||||||
|
|
||||||
if len(combined_content) > 0:
|
if len(combined_content) > 0:
|
||||||
completion_output = combined_content
|
pass
|
||||||
else:
|
else:
|
||||||
completion_output = ""
|
pass
|
||||||
# # Update usage information if needed
|
# # Update usage information if needed
|
||||||
try:
|
try:
|
||||||
response["usage"]["prompt_tokens"] = token_counter(
|
response["usage"]["prompt_tokens"] = token_counter(
|
||||||
model=model, messages=messages
|
model=model, messages=messages
|
||||||
)
|
)
|
||||||
except: # don't allow this failing to block a complete streaming response from being returned
|
except (
|
||||||
print_verbose(f"token_counter failed, assuming prompt tokens is 0")
|
Exception
|
||||||
|
): # don't allow this failing to block a complete streaming response from being returned
|
||||||
|
print_verbose("token_counter failed, assuming prompt tokens is 0")
|
||||||
response["usage"]["prompt_tokens"] = 0
|
response["usage"]["prompt_tokens"] = 0
|
||||||
response["usage"]["completion_tokens"] = token_counter(
|
response["usage"]["completion_tokens"] = token_counter(
|
||||||
model=model,
|
model=model,
|
||||||
|
|
|
@ -128,14 +128,14 @@ class LiteLLMBase(BaseModel):
|
||||||
def json(self, **kwargs): # type: ignore
|
def json(self, **kwargs): # type: ignore
|
||||||
try:
|
try:
|
||||||
return self.model_dump(**kwargs) # noqa
|
return self.model_dump(**kwargs) # noqa
|
||||||
except Exception as e:
|
except Exception:
|
||||||
# if using pydantic v1
|
# if using pydantic v1
|
||||||
return self.dict(**kwargs)
|
return self.dict(**kwargs)
|
||||||
|
|
||||||
def fields_set(self):
|
def fields_set(self):
|
||||||
try:
|
try:
|
||||||
return self.model_fields_set # noqa
|
return self.model_fields_set # noqa
|
||||||
except:
|
except Exception:
|
||||||
# if using pydantic v1
|
# if using pydantic v1
|
||||||
return self.__fields_set__
|
return self.__fields_set__
|
||||||
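The `except Exception` tightening above is part of a small pydantic v1/v2 compatibility shim: v2 exposes `model_dump` / `model_fields_set`, while v1 only has `dict()` / `__fields_set__`. A standalone sketch of the same idea (the proxy code catches Exception more broadly; AttributeError is the narrower case shown here):

from pydantic import BaseModel


class CompatModel(BaseModel):
    def as_dict(self, **kwargs) -> dict:
        try:
            return self.model_dump(**kwargs)  # pydantic v2
        except AttributeError:
            return self.dict(**kwargs)  # pydantic v1 fallback

    def set_fields(self):
        try:
            return self.model_fields_set  # pydantic v2
        except AttributeError:
            return self.__fields_set__  # pydantic v1 fallback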
|
|
||||||
|
|
|
@ -1,202 +0,0 @@
|
||||||
def add_new_model():
|
|
||||||
import streamlit as st
|
|
||||||
import json, requests, uuid
|
|
||||||
|
|
||||||
model_name = st.text_input(
|
|
||||||
"Model Name - user-facing model name", placeholder="gpt-3.5-turbo"
|
|
||||||
)
|
|
||||||
st.subheader("LiteLLM Params")
|
|
||||||
litellm_model_name = st.text_input(
|
|
||||||
"Model", placeholder="azure/gpt-35-turbo-us-east"
|
|
||||||
)
|
|
||||||
litellm_api_key = st.text_input("API Key")
|
|
||||||
litellm_api_base = st.text_input(
|
|
||||||
"API Base",
|
|
||||||
placeholder="https://my-endpoint.openai.azure.com",
|
|
||||||
)
|
|
||||||
litellm_api_version = st.text_input("API Version", placeholder="2023-07-01-preview")
|
|
||||||
litellm_params = json.loads(
|
|
||||||
st.text_area(
|
|
||||||
"Additional Litellm Params (JSON dictionary). [See all possible inputs](https://github.com/BerriAI/litellm/blob/3f15d7230fe8e7492c95a752963e7fbdcaf7bf98/litellm/main.py#L293)",
|
|
||||||
value={},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
st.subheader("Model Info")
|
|
||||||
mode_options = ("completion", "embedding", "image generation")
|
|
||||||
mode_selected = st.selectbox("Mode", mode_options)
|
|
||||||
model_info = json.loads(
|
|
||||||
st.text_area(
|
|
||||||
"Additional Model Info (JSON dictionary)",
|
|
||||||
value={},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
if st.button("Submit"):
|
|
||||||
try:
|
|
||||||
model_info = {
|
|
||||||
"model_name": model_name,
|
|
||||||
"litellm_params": {
|
|
||||||
"model": litellm_model_name,
|
|
||||||
"api_key": litellm_api_key,
|
|
||||||
"api_base": litellm_api_base,
|
|
||||||
"api_version": litellm_api_version,
|
|
||||||
},
|
|
||||||
"model_info": {
|
|
||||||
"id": str(uuid.uuid4()),
|
|
||||||
"mode": mode_selected,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
# Make the POST request to the specified URL
|
|
||||||
complete_url = ""
|
|
||||||
if st.session_state["api_url"].endswith("/"):
|
|
||||||
complete_url = f"{st.session_state['api_url']}model/new"
|
|
||||||
else:
|
|
||||||
complete_url = f"{st.session_state['api_url']}/model/new"
|
|
||||||
|
|
||||||
headers = {"Authorization": f"Bearer {st.session_state['proxy_key']}"}
|
|
||||||
response = requests.post(complete_url, json=model_info, headers=headers)
|
|
||||||
|
|
||||||
if response.status_code == 200:
|
|
||||||
st.success("Model added successfully!")
|
|
||||||
else:
|
|
||||||
st.error(f"Failed to add model. Status code: {response.status_code}")
|
|
||||||
|
|
||||||
st.success("Form submitted successfully!")
|
|
||||||
except Exception as e:
|
|
||||||
raise e
|
|
||||||
|
|
||||||
|
|
||||||
def list_models():
|
|
||||||
import streamlit as st
|
|
||||||
import requests
|
|
||||||
|
|
||||||
# Check if the necessary configuration is available
|
|
||||||
if (
|
|
||||||
st.session_state.get("api_url", None) is not None
|
|
||||||
and st.session_state.get("proxy_key", None) is not None
|
|
||||||
):
|
|
||||||
# Make the GET request
|
|
||||||
try:
|
|
||||||
complete_url = ""
|
|
||||||
if isinstance(st.session_state["api_url"], str) and st.session_state[
|
|
||||||
"api_url"
|
|
||||||
].endswith("/"):
|
|
||||||
complete_url = f"{st.session_state['api_url']}models"
|
|
||||||
else:
|
|
||||||
complete_url = f"{st.session_state['api_url']}/models"
|
|
||||||
response = requests.get(
|
|
||||||
complete_url,
|
|
||||||
headers={"Authorization": f"Bearer {st.session_state['proxy_key']}"},
|
|
||||||
)
|
|
||||||
# Check if the request was successful
|
|
||||||
if response.status_code == 200:
|
|
||||||
models = response.json()
|
|
||||||
st.write(models) # or st.json(models) to pretty print the JSON
|
|
||||||
else:
|
|
||||||
st.error(f"Failed to get models. Status code: {response.status_code}")
|
|
||||||
except Exception as e:
|
|
||||||
st.error(f"An error occurred while requesting models: {e}")
|
|
||||||
else:
|
|
||||||
st.warning(
|
|
||||||
f"Please configure the Proxy Endpoint and Proxy Key on the Proxy Setup page. Currently set Proxy Endpoint: {st.session_state.get('api_url', None)} and Proxy Key: {st.session_state.get('proxy_key', None)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def create_key():
|
|
||||||
import streamlit as st
|
|
||||||
import json, requests, uuid
|
|
||||||
|
|
||||||
if (
|
|
||||||
st.session_state.get("api_url", None) is not None
|
|
||||||
and st.session_state.get("proxy_key", None) is not None
|
|
||||||
):
|
|
||||||
duration = st.text_input("Duration - Can be in (h,m,s)", placeholder="1h")
|
|
||||||
|
|
||||||
models = st.text_input("Models it can access (separated by comma)", value="")
|
|
||||||
models = models.split(",") if models else []
|
|
||||||
|
|
||||||
additional_params = json.loads(
|
|
||||||
st.text_area(
|
|
||||||
"Additional Key Params (JSON dictionary). [See all possible inputs](https://litellm-api.up.railway.app/#/key%20management/generate_key_fn_key_generate_post)",
|
|
||||||
value={},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
if st.button("Submit"):
|
|
||||||
try:
|
|
||||||
key_post_body = {
|
|
||||||
"duration": duration,
|
|
||||||
"models": models,
|
|
||||||
**additional_params,
|
|
||||||
}
|
|
||||||
# Make the POST request to the specified URL
|
|
||||||
complete_url = ""
|
|
||||||
if st.session_state["api_url"].endswith("/"):
|
|
||||||
complete_url = f"{st.session_state['api_url']}key/generate"
|
|
||||||
else:
|
|
||||||
complete_url = f"{st.session_state['api_url']}/key/generate"
|
|
||||||
|
|
||||||
headers = {"Authorization": f"Bearer {st.session_state['proxy_key']}"}
|
|
||||||
response = requests.post(
|
|
||||||
complete_url, json=key_post_body, headers=headers
|
|
||||||
)
|
|
||||||
|
|
||||||
if response.status_code == 200:
|
|
||||||
st.success(f"Key added successfully! - {response.json()}")
|
|
||||||
else:
|
|
||||||
st.error(f"Failed to add Key. Status code: {response.status_code}")
|
|
||||||
|
|
||||||
st.success("Form submitted successfully!")
|
|
||||||
except Exception as e:
|
|
||||||
raise e
|
|
||||||
else:
|
|
||||||
st.warning(
|
|
||||||
f"Please configure the Proxy Endpoint and Proxy Key on the Proxy Setup page. Currently set Proxy Endpoint: {st.session_state.get('api_url', None)} and Proxy Key: {st.session_state.get('proxy_key', None)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def streamlit_ui():
|
|
||||||
import streamlit as st
|
|
||||||
|
|
||||||
st.header("Admin Configuration")
|
|
||||||
|
|
||||||
# Add a navigation sidebar
|
|
||||||
st.sidebar.title("Navigation")
|
|
||||||
page = st.sidebar.radio(
|
|
||||||
"Go to", ("Proxy Setup", "Add Models", "List Models", "Create Key")
|
|
||||||
)
|
|
||||||
|
|
||||||
# Initialize session state variables if not already present
|
|
||||||
if "api_url" not in st.session_state:
|
|
||||||
st.session_state["api_url"] = None
|
|
||||||
if "proxy_key" not in st.session_state:
|
|
||||||
st.session_state["proxy_key"] = None
|
|
||||||
|
|
||||||
# Display different pages based on navigation selection
|
|
||||||
if page == "Proxy Setup":
|
|
||||||
# Use text inputs with intermediary variables
|
|
||||||
input_api_url = st.text_input(
|
|
||||||
"Proxy Endpoint",
|
|
||||||
value=st.session_state.get("api_url", ""),
|
|
||||||
placeholder="http://0.0.0.0:8000",
|
|
||||||
)
|
|
||||||
input_proxy_key = st.text_input(
|
|
||||||
"Proxy Key",
|
|
||||||
value=st.session_state.get("proxy_key", ""),
|
|
||||||
placeholder="sk-...",
|
|
||||||
)
|
|
||||||
# When the "Save" button is clicked, update the session state
|
|
||||||
if st.button("Save"):
|
|
||||||
st.session_state["api_url"] = input_api_url
|
|
||||||
st.session_state["proxy_key"] = input_proxy_key
|
|
||||||
st.success("Configuration saved!")
|
|
||||||
elif page == "Add Models":
|
|
||||||
add_new_model()
|
|
||||||
elif page == "List Models":
|
|
||||||
list_models()
|
|
||||||
elif page == "Create Key":
|
|
||||||
create_key()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
streamlit_ui()
|
|
|
@ -69,7 +69,7 @@ async def get_global_activity(
|
||||||
try:
|
try:
|
||||||
if prisma_client is None:
|
if prisma_client is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Database not connected. Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys"
|
"Database not connected. Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys"
|
||||||
)
|
)
|
||||||
|
|
||||||
sql_query = """
|
sql_query = """
|
||||||
|
|
|
@ -132,7 +132,7 @@ def common_checks(
|
||||||
# 6. [OPTIONAL] If 'enforce_user_param' enabled - did developer pass in 'user' param for openai endpoints
|
# 6. [OPTIONAL] If 'enforce_user_param' enabled - did developer pass in 'user' param for openai endpoints
|
||||||
if (
|
if (
|
||||||
general_settings.get("enforce_user_param", None) is not None
|
general_settings.get("enforce_user_param", None) is not None
|
||||||
and general_settings["enforce_user_param"] == True
|
and general_settings["enforce_user_param"] is True
|
||||||
):
|
):
|
||||||
if is_llm_api_route(route=route) and "user" not in request_body:
|
if is_llm_api_route(route=route) and "user" not in request_body:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
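`enforce_user_param` is one of the common checks run on every proxy request: when enabled, OpenAI-route calls must carry a `user` field in the body. A compact sketch of just that check with a simplified signature; `is_llm_api_route` is the real helper referenced above, and the error message is illustrative:

def check_enforced_user_param(general_settings: dict, request_body: dict, route: str, is_llm_api_route) -> None:
    # Only enforced when the admin explicitly set enforce_user_param: true.
    if general_settings.get("enforce_user_param", None) is True:
        if is_llm_api_route(route=route) and "user" not in request_body:
            raise Exception(
                "BadRequest: 'user' param is required on this route when enforce_user_param is enabled"
            )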
|
@ -557,7 +557,7 @@ async def get_team_object(
|
||||||
)
|
)
|
||||||
|
|
||||||
return _response
|
return _response
|
||||||
except Exception as e:
|
except Exception:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
f"Team doesn't exist in db. Team={team_id}. Create team via `/team/new` call."
|
f"Team doesn't exist in db. Team={team_id}. Create team via `/team/new` call."
|
||||||
)
|
)
|
||||||
|
@ -664,7 +664,7 @@ async def get_org_object(
|
||||||
raise Exception
|
raise Exception
|
||||||
|
|
||||||
return response
|
return response
|
||||||
except Exception as e:
|
except Exception:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
f"Organization doesn't exist in db. Organization={org_id}. Create organization via `/organization/new` call."
|
f"Organization doesn't exist in db. Organization={org_id}. Create organization via `/organization/new` call."
|
||||||
)
|
)
|
||||||
|
|
|
@ -98,7 +98,7 @@ class LicenseCheck:
|
||||||
elif self._verify(license_str=self.license_str) is True:
|
elif self._verify(license_str=self.license_str) is True:
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
except Exception as e:
|
except Exception:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def verify_license_without_api_request(self, public_key, license_key):
|
def verify_license_without_api_request(self, public_key, license_key):
|
||||||
|
|
|
@ -112,7 +112,6 @@ async def user_api_key_auth(
|
||||||
),
|
),
|
||||||
) -> UserAPIKeyAuth:
|
) -> UserAPIKeyAuth:
|
||||||
from litellm.proxy.proxy_server import (
|
from litellm.proxy.proxy_server import (
|
||||||
custom_db_client,
|
|
||||||
general_settings,
|
general_settings,
|
||||||
jwt_handler,
|
jwt_handler,
|
||||||
litellm_proxy_admin_name,
|
litellm_proxy_admin_name,
|
||||||
|
@ -476,7 +475,7 @@ async def user_api_key_auth(
|
||||||
)
|
)
|
||||||
|
|
||||||
if route == "/user/auth":
|
if route == "/user/auth":
|
||||||
if general_settings.get("allow_user_auth", False) == True:
|
if general_settings.get("allow_user_auth", False) is True:
|
||||||
return UserAPIKeyAuth()
|
return UserAPIKeyAuth()
|
||||||
else:
|
else:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
@ -597,7 +596,7 @@ async def user_api_key_auth(
|
||||||
## VALIDATE MASTER KEY ##
|
## VALIDATE MASTER KEY ##
|
||||||
try:
|
try:
|
||||||
assert isinstance(master_key, str)
|
assert isinstance(master_key, str)
|
||||||
except Exception as e:
|
except Exception:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=500,
|
status_code=500,
|
||||||
detail={
|
detail={
|
||||||
|
@ -648,7 +647,7 @@ async def user_api_key_auth(
|
||||||
)
|
)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
prisma_client is None and custom_db_client is None
|
prisma_client is None
|
||||||
): # if both master key + user key submitted, and user key != master key, and no db connected, raise an error
|
): # if both master key + user key submitted, and user key != master key, and no db connected, raise an error
|
||||||
raise Exception("No connected db.")
|
raise Exception("No connected db.")
|
||||||
|
|
||||||
|
@ -722,9 +721,9 @@ async def user_api_key_auth(
|
||||||
|
|
||||||
if config != {}:
|
if config != {}:
|
||||||
model_list = config.get("model_list", [])
|
model_list = config.get("model_list", [])
|
||||||
llm_model_list = model_list
|
new_model_list = model_list
|
||||||
verbose_proxy_logger.debug(
|
verbose_proxy_logger.debug(
|
||||||
f"\n new llm router model list {llm_model_list}"
|
f"\n new llm router model list {new_model_list}"
|
||||||
)
|
)
|
||||||
if (
|
if (
|
||||||
len(valid_token.models) == 0
|
len(valid_token.models) == 0
|
||||||
|
|
|
@ -2,6 +2,7 @@ import sys
|
||||||
from typing import Any, Dict, List, Optional, get_args
|
from typing import Any, Dict, List, Optional, get_args
|
||||||
|
|
||||||
import litellm
|
import litellm
|
||||||
|
from litellm import get_secret, get_secret_str
|
||||||
from litellm._logging import verbose_proxy_logger
|
from litellm._logging import verbose_proxy_logger
|
||||||
from litellm.proxy._types import CommonProxyErrors, LiteLLMPromptInjectionParams
|
from litellm.proxy._types import CommonProxyErrors, LiteLLMPromptInjectionParams
|
||||||
from litellm.proxy.utils import get_instance_fn
|
from litellm.proxy.utils import get_instance_fn
|
||||||
|
@ -59,9 +60,15 @@ def initialize_callbacks_on_proxy(
|
||||||
presidio_logging_only
|
presidio_logging_only
|
||||||
) # validate boolean given
|
) # validate boolean given
|
||||||
|
|
||||||
params = {
|
_presidio_params = {}
|
||||||
|
if "presidio" in callback_specific_params and isinstance(
|
||||||
|
callback_specific_params["presidio"], dict
|
||||||
|
):
|
||||||
|
_presidio_params = callback_specific_params["presidio"]
|
||||||
|
|
||||||
|
params: Dict[str, Any] = {
|
||||||
"logging_only": presidio_logging_only,
|
"logging_only": presidio_logging_only,
|
||||||
**callback_specific_params.get("presidio", {}),
|
**_presidio_params,
|
||||||
}
|
}
|
||||||
pii_masking_object = _OPTIONAL_PresidioPIIMasking(**params)
|
pii_masking_object = _OPTIONAL_PresidioPIIMasking(**params)
|
||||||
imported_list.append(pii_masking_object)
|
imported_list.append(pii_masking_object)
|
||||||
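The new `_presidio_params` block above defends against `callback_specific_params["presidio"]` being absent or not a dict before merging it with the default `logging_only` flag. The merge itself is just a dict-spread with per-callback overrides winning; a minimal sketch:

from typing import Any, Dict


def build_presidio_params(callback_specific_params: Dict[str, Any], presidio_logging_only) -> Dict[str, Any]:
    # Start from the defaults, then layer any per-callback presidio settings
    # on top (only if they were actually provided as a dict).
    _presidio_params: Dict[str, Any] = {}
    if "presidio" in callback_specific_params and isinstance(
        callback_specific_params["presidio"], dict
    ):
        _presidio_params = callback_specific_params["presidio"]
    return {"logging_only": presidio_logging_only, **_presidio_params}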
|
@ -70,7 +77,7 @@ def initialize_callbacks_on_proxy(
|
||||||
_ENTERPRISE_LlamaGuard,
|
_ENTERPRISE_LlamaGuard,
|
||||||
)
|
)
|
||||||
|
|
||||||
if premium_user != True:
|
if premium_user is not True:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
"Trying to use Llama Guard"
|
"Trying to use Llama Guard"
|
||||||
+ CommonProxyErrors.not_premium_user.value
|
+ CommonProxyErrors.not_premium_user.value
|
||||||
|
@ -83,7 +90,7 @@ def initialize_callbacks_on_proxy(
|
||||||
_ENTERPRISE_SecretDetection,
|
_ENTERPRISE_SecretDetection,
|
||||||
)
|
)
|
||||||
|
|
||||||
if premium_user != True:
|
if premium_user is not True:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
"Trying to use secret hiding"
|
"Trying to use secret hiding"
|
||||||
+ CommonProxyErrors.not_premium_user.value
|
+ CommonProxyErrors.not_premium_user.value
|
||||||
|
@ -96,7 +103,7 @@ def initialize_callbacks_on_proxy(
|
||||||
_ENTERPRISE_OpenAI_Moderation,
|
_ENTERPRISE_OpenAI_Moderation,
|
||||||
)
|
)
|
||||||
|
|
||||||
if premium_user != True:
|
if premium_user is not True:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
"Trying to use OpenAI Moderations Check"
|
"Trying to use OpenAI Moderations Check"
|
||||||
+ CommonProxyErrors.not_premium_user.value
|
+ CommonProxyErrors.not_premium_user.value
|
||||||
|
@ -126,7 +133,7 @@ def initialize_callbacks_on_proxy(
|
||||||
_ENTERPRISE_GoogleTextModeration,
|
_ENTERPRISE_GoogleTextModeration,
|
||||||
)
|
)
|
||||||
|
|
||||||
if premium_user != True:
|
if premium_user is not True:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
"Trying to use Google Text Moderation"
|
"Trying to use Google Text Moderation"
|
||||||
+ CommonProxyErrors.not_premium_user.value
|
+ CommonProxyErrors.not_premium_user.value
|
||||||
|
@@ -137,7 +144,7 @@ def initialize_callbacks_on_proxy(
         elif isinstance(callback, str) and callback == "llmguard_moderations":
             from enterprise.enterprise_hooks.llm_guard import _ENTERPRISE_LLMGuard

-            if premium_user != True:
+            if premium_user is not True:
                 raise Exception(
                     "Trying to use Llm Guard"
                     + CommonProxyErrors.not_premium_user.value
@@ -150,7 +157,7 @@ def initialize_callbacks_on_proxy(
                 _ENTERPRISE_BlockedUserList,
             )

-            if premium_user != True:
+            if premium_user is not True:
                 raise Exception(
                     "Trying to use ENTERPRISE BlockedUser"
                     + CommonProxyErrors.not_premium_user.value
@@ -165,7 +172,7 @@ def initialize_callbacks_on_proxy(
                 _ENTERPRISE_BannedKeywords,
            )

-            if premium_user != True:
+            if premium_user is not True:
                 raise Exception(
                     "Trying to use ENTERPRISE BannedKeyword"
                     + CommonProxyErrors.not_premium_user.value
@@ -212,7 +219,7 @@ def initialize_callbacks_on_proxy(
                     and isinstance(v, str)
                     and v.startswith("os.environ/")
                 ):
-                    azure_content_safety_params[k] = litellm.get_secret(v)
+                    azure_content_safety_params[k] = get_secret(v)

             azure_content_safety_obj = _PROXY_AzureContentSafety(
                 **azure_content_safety_params,
@@ -6,13 +6,14 @@ import tracemalloc
 from fastapi import APIRouter

 import litellm
+from litellm import get_secret, get_secret_str
 from litellm._logging import verbose_proxy_logger

 router = APIRouter()

 if os.environ.get("LITELLM_PROFILE", "false").lower() == "true":
     try:
-        import objgraph
+        import objgraph  # type: ignore

         print("growth of objects")  # noqa
         objgraph.show_growth()
@@ -21,8 +22,10 @@ if os.environ.get("LITELLM_PROFILE", "false").lower() == "true":
         roots = objgraph.get_leaking_objects()
         print("\n\nLeaking objects")  # noqa
         objgraph.show_most_common_types(objects=roots)
-    except:
-        pass
+    except ImportError:
+        raise ImportError(
+            "objgraph not found. Please install objgraph to use this feature."
+        )

     tracemalloc.start(10)

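Note: ruff rejects bare `except:` clauses (rule E722) because they also swallow KeyboardInterrupt and SystemExit; the hunk above narrows the handler to ImportError and surfaces an actionable message instead of silently passing. The `# type: ignore` is there because objgraph ships without type stubs. A minimal sketch mirroring the pattern from the hunk:

    try:
        import objgraph  # type: ignore  # optional profiling dependency, no stubs
    except ImportError:
        # E722: a bare `except:` would also swallow KeyboardInterrupt/SystemExit;
        # catching ImportError and re-raising keeps the failure actionable
        raise ImportError(
            "objgraph not found. Please install objgraph to use this feature."
        )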
@@ -57,15 +60,20 @@ async def memory_usage_in_mem_cache():
         user_api_key_cache,
     )

+    if llm_router is None:
+        num_items_in_llm_router_cache = 0
+    else:
+        num_items_in_llm_router_cache = len(
+            llm_router.cache.in_memory_cache.cache_dict
+        ) + len(llm_router.cache.in_memory_cache.ttl_dict)
+
     num_items_in_user_api_key_cache = len(
         user_api_key_cache.in_memory_cache.cache_dict
     ) + len(user_api_key_cache.in_memory_cache.ttl_dict)
-    num_items_in_llm_router_cache = len(
-        llm_router.cache.in_memory_cache.cache_dict
-    ) + len(llm_router.cache.in_memory_cache.ttl_dict)
     num_items_in_proxy_logging_obj_cache = len(
-        proxy_logging_obj.internal_usage_cache.in_memory_cache.cache_dict
-    ) + len(proxy_logging_obj.internal_usage_cache.in_memory_cache.ttl_dict)
+        proxy_logging_obj.internal_usage_cache.dual_cache.in_memory_cache.cache_dict
+    ) + len(proxy_logging_obj.internal_usage_cache.dual_cache.in_memory_cache.ttl_dict)

     return {
         "num_items_in_user_api_key_cache": num_items_in_user_api_key_cache,
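Note: `llm_router` is Optional on the proxy, so calling `len()` on its cache dicts without a guard fails type checking; the hunk above defaults the count to 0 when the router is not initialized. A small sketch of the guard; the Router and InMemoryCache classes below are simplified stand-ins, not litellm's actual objects:

    from typing import Optional

    class InMemoryCache:
        def __init__(self) -> None:
            self.cache_dict: dict = {}
            self.ttl_dict: dict = {}

    class Router:
        def __init__(self) -> None:
            self.in_memory_cache = InMemoryCache()

    def count_router_cache_items(llm_router: Optional[Router]) -> int:
        # guard the Optional router before touching its caches
        if llm_router is None:
            return 0
        return len(llm_router.in_memory_cache.cache_dict) + len(
            llm_router.in_memory_cache.ttl_dict
        )

    print(count_router_cache_items(None))      # 0
    print(count_router_cache_items(Router()))  # 0 until something is cached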
@@ -89,13 +97,20 @@ async def memory_usage_in_mem_cache_items():
         user_api_key_cache,
     )

+    if llm_router is None:
+        llm_router_in_memory_cache_dict = {}
+        llm_router_in_memory_ttl_dict = {}
+    else:
+        llm_router_in_memory_cache_dict = llm_router.cache.in_memory_cache.cache_dict
+        llm_router_in_memory_ttl_dict = llm_router.cache.in_memory_cache.ttl_dict
+
     return {
         "user_api_key_cache": user_api_key_cache.in_memory_cache.cache_dict,
         "user_api_key_ttl": user_api_key_cache.in_memory_cache.ttl_dict,
-        "llm_router_cache": llm_router.cache.in_memory_cache.cache_dict,
-        "llm_router_ttl": llm_router.cache.in_memory_cache.ttl_dict,
+        "llm_router_cache": llm_router_in_memory_cache_dict,
+        "llm_router_ttl": llm_router_in_memory_ttl_dict,
-        "proxy_logging_obj_cache": proxy_logging_obj.internal_usage_cache.in_memory_cache.cache_dict,
-        "proxy_logging_obj_ttl": proxy_logging_obj.internal_usage_cache.in_memory_cache.ttl_dict,
+        "proxy_logging_obj_cache": proxy_logging_obj.internal_usage_cache.dual_cache.in_memory_cache.cache_dict,
+        "proxy_logging_obj_ttl": proxy_logging_obj.internal_usage_cache.dual_cache.in_memory_cache.ttl_dict,
     }

@@ -104,9 +119,18 @@ async def get_otel_spans():
     from litellm.integrations.opentelemetry import OpenTelemetry
     from litellm.proxy.proxy_server import open_telemetry_logger

-    open_telemetry_logger: OpenTelemetry = open_telemetry_logger
+    if open_telemetry_logger is None:
+        return {
+            "otel_spans": [],
+            "spans_grouped_by_parent": {},
+            "most_recent_parent": None,
+        }
+
     otel_exporter = open_telemetry_logger.OTEL_EXPORTER
-    recorded_spans = otel_exporter.get_finished_spans()
+    if hasattr(otel_exporter, "get_finished_spans"):
+        recorded_spans = otel_exporter.get_finished_spans()  # type: ignore
+    else:
+        recorded_spans = []

     print("Spans: ", recorded_spans)  # noqa

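Note: `open_telemetry_logger` may be None when OTEL is not configured, and only some exporters expose `get_finished_spans()`; the hunk above returns an empty payload in the first case and falls back to an empty list in the second. A hedged sketch of the hasattr probe; the exporter classes here are illustrative, not litellm's:

    from typing import Any, List

    def collect_finished_spans(otel_exporter: Any) -> List[Any]:
        # only in-memory/test exporters expose get_finished_spans();
        # network exporters generally do not, so probe before calling
        if hasattr(otel_exporter, "get_finished_spans"):
            return list(otel_exporter.get_finished_spans())
        return []

    class _FakeInMemoryExporter:
        def get_finished_spans(self) -> List[str]:
            return ["span-1", "span-2"]

    print(collect_finished_spans(_FakeInMemoryExporter()))  # ['span-1', 'span-2']
    print(collect_finished_spans(object()))                 # []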
@@ -137,11 +161,13 @@ async def get_otel_spans():
 # Helper functions for debugging
 def init_verbose_loggers():
     try:
-        worker_config = litellm.get_secret("WORKER_CONFIG")
+        worker_config = get_secret_str("WORKER_CONFIG")
+        # if not, assume it's a json string
+        if worker_config is None:
+            return
         if os.path.isfile(worker_config):
             return
-        # if not, assume it's a json string
-        _settings = json.loads(os.getenv("WORKER_CONFIG"))
+        _settings = json.loads(worker_config)
         if not isinstance(_settings, dict):
             return

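Note: `get_secret` can return non-string values, so the hunk above switches to `get_secret_str` and returns early when WORKER_CONFIG is unset; that keeps the later `os.path.isfile()` and `json.loads()` calls type-safe and avoids re-reading the env var. A small sketch of the narrowing pattern; the wrapper load_worker_settings is hypothetical:

    import json
    import os
    from typing import Any, Dict, Optional

    def load_worker_settings(worker_config: Optional[str]) -> Optional[Dict[str, Any]]:
        # narrow the Optional[str] before handing it to os.path / json
        if worker_config is None:
            return None
        if os.path.isfile(worker_config):
            return None  # config lives on disk; nothing to parse here
        _settings = json.loads(worker_config)
        if not isinstance(_settings, dict):
            return None
        return _settings

    print(load_worker_settings('{"detailed_debug": true}'))  # {'detailed_debug': True}
    print(load_worker_settings(None))                        # None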
@@ -162,7 +188,7 @@ def init_verbose_loggers():
                 level=logging.INFO
             )  # set router logs to info
             verbose_proxy_logger.setLevel(level=logging.INFO)  # set proxy logs to info
-        if detailed_debug == True:
+        if detailed_debug is True:
             import logging

             from litellm._logging import (
@@ -178,10 +204,10 @@ def init_verbose_loggers():
             verbose_proxy_logger.setLevel(
                 level=logging.DEBUG
             )  # set proxy logs to debug
-        elif debug == False and detailed_debug == False:
+        elif debug is False and detailed_debug is False:
             # users can control proxy debugging using env variable = 'LITELLM_LOG'
             litellm_log_setting = os.environ.get("LITELLM_LOG", "")
-            if litellm_log_setting != None:
+            if litellm_log_setting is not None:
                 if litellm_log_setting.upper() == "INFO":
                     import logging

@@ -213,4 +239,6 @@ def init_verbose_loggers():
                         level=logging.DEBUG
                     )  # set proxy logs to debug
     except Exception as e:
-        verbose_logger.warning(f"Failed to init verbose loggers: {str(e)}")
+        import logging
+
+        logging.warning(f"Failed to init verbose loggers: {str(e)}")
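Note: the final hunk swaps `verbose_logger` (which is not imported in this module after the cleanup, presumably the source of the lint error) for the stdlib logger inside the except block. A tiny sketch of the fallback; the wrapper names below are illustrative:

    import logging

    def init_verbose_loggers_safely(init_fn) -> None:
        try:
            init_fn()
        except Exception as e:
            # fall back to the stdlib logger so a broken logging setup
            # cannot itself crash proxy startup
            logging.warning(f"Failed to init verbose loggers: {str(e)}")

    def _broken_init() -> None:
        raise RuntimeError("boom")

    init_verbose_loggers_safely(_broken_init)  # emits a warning, does not raise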
Some files were not shown because too many files have changed in this diff.