forked from phoenix/litellm-mirror
Merge branch 'main' into litellm_add_secret_detection
commit 9942a5cbcf
7 changed files with 196 additions and 44 deletions
Dockerfile

@@ -9,6 +9,27 @@ FROM $LITELLM_BUILD_IMAGE as builder
 # Set the working directory to /app
 WORKDIR /app
 
+ARG LITELLM_USER=litellm LITELLM_UID=1729
+ARG LITELLM_GROUP=litellm LITELLM_GID=1729
+
+RUN groupadd \
+    --gid ${LITELLM_GID} \
+    ${LITELLM_GROUP} \
+    && useradd \
+    --create-home \
+    --shell /bin/sh \
+    --gid ${LITELLM_GID} \
+    --uid ${LITELLM_UID} \
+    ${LITELLM_USER}
+
+# Allows user to update python install.
+# This is necessary for prisma.
+RUN chown -R ${LITELLM_USER}:${LITELLM_GROUP} /usr/local/lib/python3.11
+
+# Set the HOME var forcefully because of prisma.
+ENV HOME=/home/${LITELLM_USER}
+USER ${LITELLM_USER}
+
 # Install build dependencies
 RUN apt-get clean && apt-get update && \
     apt-get install -y gcc python3-dev && \
Predibase integration module

@@ -1,27 +1,28 @@
 # What is this?
 ## Controller file for Predibase Integration - https://predibase.com/
 
-from functools import partial
-import os, types
-import traceback
-import json
-from enum import Enum
-import requests, copy # type: ignore
-import time
-from typing import Callable, Optional, List, Literal, Union
-from litellm.utils import (
-    ModelResponse,
-    Usage,
-    CustomStreamWrapper,
-    Message,
-    Choices,
-)
-from litellm.litellm_core_utils.core_helpers import map_finish_reason
-import litellm
-from .prompt_templates.factory import prompt_factory, custom_prompt
-from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
-from .base import BaseLLM
-import httpx # type: ignore
+import copy
+import json
+import os
+import time
+import traceback
+import types
+from enum import Enum
+from functools import partial
+from typing import Callable, List, Literal, Optional, Union
+
+import httpx # type: ignore
+import requests # type: ignore
+
+import litellm
+import litellm.litellm_core_utils
+import litellm.litellm_core_utils.litellm_logging
+from litellm.litellm_core_utils.core_helpers import map_finish_reason
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
+from litellm.utils import Choices, CustomStreamWrapper, Message, ModelResponse, Usage
+
+from .base import BaseLLM
+from .prompt_templates.factory import custom_prompt, prompt_factory
 
 
 class PredibaseError(Exception):
@@ -146,7 +147,49 @@ class PredibaseConfig:
         }
 
     def get_supported_openai_params(self):
-        return ["stream", "temperature", "max_tokens", "top_p", "stop", "n"]
+        return [
+            "stream",
+            "temperature",
+            "max_tokens",
+            "top_p",
+            "stop",
+            "n",
+            "response_format",
+        ]
+
+    def map_openai_params(self, non_default_params: dict, optional_params: dict):
+        for param, value in non_default_params.items():
+            # temperature, top_p, n, stream, stop, max_tokens, n, presence_penalty default to None
+            if param == "temperature":
+                if value == 0.0 or value == 0:
+                    # hugging face exception raised when temp==0
+                    # Failed: Error occurred: HuggingfaceException - Input validation error: `temperature` must be strictly positive
+                    value = 0.01
+                optional_params["temperature"] = value
+            if param == "top_p":
+                optional_params["top_p"] = value
+            if param == "n":
+                optional_params["best_of"] = value
+                optional_params["do_sample"] = (
+                    True  # Need to sample if you want best of for hf inference endpoints
+                )
+            if param == "stream":
+                optional_params["stream"] = value
+            if param == "stop":
+                optional_params["stop"] = value
+            if param == "max_tokens":
+                # HF TGI raises the following exception when max_new_tokens==0
+                # Failed: Error occurred: HuggingfaceException - Input validation error: `max_new_tokens` must be strictly positive
+                if value == 0:
+                    value = 1
+                optional_params["max_new_tokens"] = value
+            if param == "echo":
+                # https://huggingface.co/docs/huggingface_hub/main/en/package_reference/inference_client#huggingface_hub.InferenceClient.text_generation.decoder_input_details
+                # Return the decoder input token logprobs and ids. You must set details=True as well for it to be taken into account. Defaults to False
+                optional_params["decoder_input_details"] = True
+            if param == "response_format":
+                optional_params["response_format"] = value
+        return optional_params
 
 
 class PredibaseChatCompletion(BaseLLM):
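Illustration only (not part of the diff): a minimal sketch of what the new map_openai_params does with a few OpenAI-style params, calling litellm.PredibaseConfig exactly as the utils.py hunk further down does. The parameter values here are made up.

import litellm

params = litellm.PredibaseConfig().map_openai_params(
    non_default_params={
        "temperature": 0,  # bumped to 0.01 to avoid the "strictly positive" TGI error
        "n": 2,            # mapped to best_of, with do_sample=True
        "max_tokens": 0,   # bumped to max_new_tokens=1
        "response_format": {"type": "json_object"},  # newly supported, passed through
    },
    optional_params={},
)
# params == {"temperature": 0.01, "best_of": 2, "do_sample": True,
#            "max_new_tokens": 1, "response_format": {"type": "json_object"}}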
@@ -225,15 +268,16 @@ class PredibaseChatCompletion(BaseLLM):
                 status_code=response.status_code,
             )
         else:
-            if (
-                not isinstance(completion_response, dict)
-                or "generated_text" not in completion_response
-            ):
-                raise PredibaseError(
-                    status_code=422,
-                    message=f"response is not in expected format - {completion_response}",
-                )
+            if not isinstance(completion_response, dict):
+                raise PredibaseError(
+                    status_code=422,
+                    message=f"'completion_response' is not a dictionary - {completion_response}",
+                )
+            elif "generated_text" not in completion_response:
+                raise PredibaseError(
+                    status_code=422,
+                    message=f"'generated_text' is not a key response dictionary - {completion_response}",
+                )
 
             if len(completion_response["generated_text"]) > 0:
                 model_response["choices"][0]["message"]["content"] = self.output_parser(
                     completion_response["generated_text"]
@@ -496,7 +540,9 @@ class PredibaseChatCompletion(BaseLLM):
         except httpx.HTTPStatusError as e:
             raise PredibaseError(
                 status_code=e.response.status_code,
-                message="HTTPStatusError - {}".format(e.response.text),
+                message="HTTPStatusError - received status_code={}, error_message={}".format(
+                    e.response.status_code, e.response.text
+                ),
             )
         except Exception as e:
             raise PredibaseError(
Proxy config (YAML)

@@ -14,13 +14,10 @@ model_list:
   - model_name: fake-openai-endpoint
     litellm_params:
       model: predibase/llama-3-8b-instruct
-      api_base: "http://0.0.0.0:8000"
+      api_base: "http://0.0.0.0:8081"
       api_key: os.environ/PREDIBASE_API_KEY
       tenant_id: os.environ/PREDIBASE_TENANT_ID
-      max_retries: 0
-      temperature: 0.1
       max_new_tokens: 256
-      return_full_text: false
 
 # - litellm_params:
 #     api_base: https://my-endpoint-europe-berri-992.openai.azure.com/
@@ -73,6 +70,8 @@ model_list:
 
 litellm_settings:
   callbacks: ["dynamic_rate_limiter"]
+  # success_callback: ["langfuse"]
+  # failure_callback: ["langfuse"]
 # default_team_settings:
 #   - team_id: proj1
 #     success_callback: ["langfuse"]
@@ -94,8 +93,8 @@ assistant_settings:
 router_settings:
   enable_pre_call_checks: true
 
-general_settings:
-  alerting: ["slack"]
-  enable_jwt_auth: True
-  litellm_jwtauth:
-    team_id_jwt_field: "client_id"
+# general_settings:
+#   alerting: ["slack"]
+#   enable_jwt_auth: True
+#   litellm_jwtauth:
+#     team_id_jwt_field: "client_id"
License check module (lives alongside the new public_key.pem below)

@@ -1,6 +1,11 @@
 # What is this?
 ## If litellm license in env, checks if it's valid
+import base64
+import json
 import os
+from datetime import datetime
+
+from litellm._logging import verbose_proxy_logger
 from litellm.llms.custom_httpx.http_handler import HTTPHandler
 
 
@@ -15,6 +20,26 @@ class LicenseCheck:
     def __init__(self) -> None:
         self.license_str = os.getenv("LITELLM_LICENSE", None)
         self.http_handler = HTTPHandler()
+        self.public_key = None
+        self.read_public_key()
+
+    def read_public_key(self):
+        try:
+            from cryptography.hazmat.primitives import hashes, serialization
+            from cryptography.hazmat.primitives.asymmetric import padding, rsa
+
+            # current dir
+            current_dir = os.path.dirname(os.path.realpath(__file__))
+
+            # check if public_key.pem exists
+            _path_to_public_key = os.path.join(current_dir, "public_key.pem")
+            if os.path.exists(_path_to_public_key):
+                with open(_path_to_public_key, "rb") as key_file:
+                    self.public_key = serialization.load_pem_public_key(key_file.read())
+            else:
+                self.public_key = None
+        except Exception as e:
+            verbose_proxy_logger.error(f"Error reading public key: {str(e)}")
 
     def _verify(self, license_str: str) -> bool:
         url = "{}/verify_license/{}".format(self.base_url, license_str)
@@ -35,11 +60,58 @@ class LicenseCheck:
             return False
 
     def is_premium(self) -> bool:
+        """
+        1. verify_license_without_api_request: checks if license was generate using private / public key pair
+        2. _verify: checks if license is valid calling litellm API. This is the old way we were generating/validating license
+        """
         try:
             if self.license_str is None:
                 return False
+            elif self.verify_license_without_api_request(
+                public_key=self.public_key, license_key=self.license_str
+            ):
+                return True
             elif self._verify(license_str=self.license_str):
                 return True
             return False
         except Exception as e:
             return False
+
+    def verify_license_without_api_request(self, public_key, license_key):
+        try:
+            from cryptography.hazmat.primitives import hashes, serialization
+            from cryptography.hazmat.primitives.asymmetric import padding, rsa
+
+            # Decode the license key
+            decoded = base64.b64decode(license_key)
+            message, signature = decoded.split(b".", 1)
+
+            # Verify the signature
+            public_key.verify(
+                signature,
+                message,
+                padding.PSS(
+                    mgf=padding.MGF1(hashes.SHA256()),
+                    salt_length=padding.PSS.MAX_LENGTH,
+                ),
+                hashes.SHA256(),
+            )
+
+            # Decode and parse the data
+            license_data = json.loads(message.decode())
+
+            # debug information provided in license data
+            verbose_proxy_logger.debug("License data: %s", license_data)
+
+            # Check expiration date
+            expiration_date = datetime.strptime(
+                license_data["expiration_date"], "%Y-%m-%d"
+            )
+            if expiration_date < datetime.now():
+                return False, "License has expired"
+
+            return True
+
+        except Exception as e:
+            verbose_proxy_logger.error(str(e))
+            return False
litellm/proxy/auth/public_key.pem (new file, 9 lines)

@@ -0,0 +1,9 @@
+-----BEGIN PUBLIC KEY-----
+MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAmfBuNiNzDkNWyce23koQ
+w0vq3bSVHkq7fd9Sw/U1q7FwRwL221daLTyGWssd8xAoQSFXAJKoBwzJQ9wd+o44
+lfL54E3a61nfjZuF+D9ntpXZFfEAxLVtIahDeQjUz4b/EpgciWIJyUfjCJrQo6LY
+eyAZPTGSO8V3zHyaU+CFywq5XCuCnfZqCZeCw051St59A2v8W32mXSCJ+A+x0hYP
+yXJyRRFcefSFG5IBuRHr4Y24Vx7NUIAoco5cnxJho9g2z3J/Hb0GKW+oBNvRVumk
+nuA2Ljmjh4yI0OoTIW8ZWxemvCCJHSjdfKlMyb+QI4fmeiIUZzP5Au+F561Styqq
+YQIDAQAB
+-----END PUBLIC KEY-----
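For context only (no such script ships in this commit): a sketch of how a license key accepted by verify_license_without_api_request() could be produced, assuming an RSA key pair. The key generated below is hypothetical; the shipped public_key.pem pairs with litellm's own private key.

import base64
import json

from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import padding, rsa

# Hypothetical signing key, standing in for the private half of public_key.pem.
private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)

# Payload read by the verifier; it checks "expiration_date" (YYYY-MM-DD).
message = json.dumps({"expiration_date": "2030-01-01"}).encode()

# Same PSS / SHA-256 scheme the verifier uses.
signature = private_key.sign(
    message,
    padding.PSS(mgf=padding.MGF1(hashes.SHA256()), salt_length=padding.PSS.MAX_LENGTH),
    hashes.SHA256(),
)

# The verifier base64-decodes the key and splits on the first b".".
license_key = base64.b64encode(message + b"." + signature).decode()

# The matching public key, in the PEM format read by read_public_key().
pem = private_key.public_key().public_bytes(
    encoding=serialization.Encoding.PEM,
    format=serialization.PublicFormat.SubjectPublicKeyInfo,
)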
utils.py (get_optional_params / exception_type)

@@ -2609,7 +2609,15 @@ def get_optional_params(
             optional_params["top_p"] = top_p
         if stop is not None:
             optional_params["stop_sequences"] = stop
-    elif custom_llm_provider == "huggingface" or custom_llm_provider == "predibase":
+    elif custom_llm_provider == "predibase":
+        supported_params = get_supported_openai_params(
+            model=model, custom_llm_provider=custom_llm_provider
+        )
+        _check_valid_arg(supported_params=supported_params)
+        optional_params = litellm.PredibaseConfig().map_openai_params(
+            non_default_params=non_default_params, optional_params=optional_params
+        )
+    elif custom_llm_provider == "huggingface":
         ## check if unsupported param passed in
         supported_params = get_supported_openai_params(
             model=model, custom_llm_provider=custom_llm_provider
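A quick check, not in the diff, of the behavioural change above: Predibase now has its own branch in get_optional_params, so response_format is advertised and forwarded instead of falling through the HuggingFace path. This assumes get_supported_openai_params routes the predibase provider to PredibaseConfig, which the new branch itself relies on.

import litellm

supported = litellm.get_supported_openai_params(
    model="llama-3-8b-instruct", custom_llm_provider="predibase"
)
assert "response_format" in supported  # new in this commit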
@@ -6157,13 +6165,6 @@ def exception_type(
                     response=original_exception.response,
                     litellm_debug_info=extra_information,
                 )
-            if "Request failed during generation" in error_str:
-                # this is an internal server error from predibase
-                raise litellm.InternalServerError(
-                    message=f"PredibaseException - {error_str}",
-                    llm_provider="predibase",
-                    model=model,
-                )
             elif hasattr(original_exception, "status_code"):
                 if original_exception.status_code == 500:
                     exception_mapping_worked = True
@@ -6201,7 +6202,10 @@ def exception_type(
                         llm_provider=custom_llm_provider,
                         litellm_debug_info=extra_information,
                     )
-                elif original_exception.status_code == 422:
+                elif (
+                    original_exception.status_code == 422
+                    or original_exception.status_code == 424
+                ):
                     exception_mapping_worked = True
                     raise BadRequestError(
                         message=f"PredibaseException - {original_exception.message}",
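Illustrative only (the call below is not from the diff; the model name comes from the proxy config above): with the widened elif, a 424 from Predibase is surfaced as litellm.BadRequestError, the same as a 422, so callers can handle both identically.

import litellm

try:
    litellm.completion(
        model="predibase/llama-3-8b-instruct",
        messages=[{"role": "user", "content": "hello"}],
    )
except litellm.BadRequestError as e:
    # Predibase 422 and 424 responses are both mapped here after this change.
    print(e.status_code, e.message)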
requirements.txt

@@ -32,6 +32,7 @@ opentelemetry-api==1.25.0
 opentelemetry-sdk==1.25.0
 opentelemetry-exporter-otlp==1.25.0
 detect-secrets==1.5.0 # Enterprise - secret detection / masking in LLM requests
+cryptography==42.0.7
 
 ### LITELLM PACKAGE DEPENDENCIES
 python-dotenv==1.0.0 # for env