diff --git a/Dockerfile.database b/Dockerfile.database
index 22084bab8..1901200d5 100644
--- a/Dockerfile.database
+++ b/Dockerfile.database
@@ -9,6 +9,27 @@ FROM $LITELLM_BUILD_IMAGE as builder
 # Set the working directory to /app
 WORKDIR /app
 
+ARG LITELLM_USER=litellm LITELLM_UID=1729
+ARG LITELLM_GROUP=litellm LITELLM_GID=1729
+
+RUN groupadd \
+    --gid ${LITELLM_GID} \
+    ${LITELLM_GROUP} \
+    && useradd \
+    --create-home \
+    --shell /bin/sh \
+    --gid ${LITELLM_GID} \
+    --uid ${LITELLM_UID} \
+    ${LITELLM_USER}
+
+# Allows user to update python install.
+# This is necessary for prisma.
+RUN chown -R ${LITELLM_USER}:${LITELLM_GROUP} /usr/local/lib/python3.11
+
+# Set the HOME var forcefully because of prisma.
+ENV HOME=/home/${LITELLM_USER}
+USER ${LITELLM_USER}
+
 # Install build dependencies
 RUN apt-get clean && apt-get update && \
     apt-get install -y gcc python3-dev && \
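Review note: a quick way to sanity-check the non-root setup above is to confirm the built image reports the `litellm` user at runtime. A minimal sketch; the image tag "litellm-database:local" is a stand-in, not something defined by this change.

# Sketch (assumed image tag "litellm-database:local"): verify the container built from
# Dockerfile.database runs as the non-root user created above.
import subprocess

result = subprocess.run(
    ["docker", "run", "--rm", "--entrypoint", "id", "litellm-database:local"],
    capture_output=True,
    text=True,
    check=True,
)
print(result.stdout)  # expected to include uid=1729(litellm) gid=1729(litellm)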
diff --git a/litellm/llms/predibase.py b/litellm/llms/predibase.py
index 8ad294457..534f8e26f 100644
--- a/litellm/llms/predibase.py
+++ b/litellm/llms/predibase.py
@@ -1,27 +1,28 @@
 # What is this?
 ## Controller file for Predibase Integration - https://predibase.com/
-from functools import partial
-import os, types
-import traceback
+import copy
 import json
-from enum import Enum
-import requests, copy  # type: ignore
+import os
 import time
-from typing import Callable, Optional, List, Literal, Union
-from litellm.utils import (
-    ModelResponse,
-    Usage,
-    CustomStreamWrapper,
-    Message,
-    Choices,
-)
-from litellm.litellm_core_utils.core_helpers import map_finish_reason
-import litellm
-from .prompt_templates.factory import prompt_factory, custom_prompt
-from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
-from .base import BaseLLM
+import traceback
+import types
+from enum import Enum
+from functools import partial
+from typing import Callable, List, Literal, Optional, Union
+
 import httpx  # type: ignore
+import requests  # type: ignore
+
+import litellm
+import litellm.litellm_core_utils
+import litellm.litellm_core_utils.litellm_logging
+from litellm.litellm_core_utils.core_helpers import map_finish_reason
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
+from litellm.utils import Choices, CustomStreamWrapper, Message, ModelResponse, Usage
+
+from .base import BaseLLM
+from .prompt_templates.factory import custom_prompt, prompt_factory
 
 
 class PredibaseError(Exception):
@@ -146,7 +147,49 @@ class PredibaseConfig:
         }
 
     def get_supported_openai_params(self):
-        return ["stream", "temperature", "max_tokens", "top_p", "stop", "n"]
+        return [
+            "stream",
+            "temperature",
+            "max_tokens",
+            "top_p",
+            "stop",
+            "n",
+            "response_format",
+        ]
+
+    def map_openai_params(self, non_default_params: dict, optional_params: dict):
+        for param, value in non_default_params.items():
+            # temperature, top_p, n, stream, stop, max_tokens, presence_penalty default to None
+            if param == "temperature":
+                if value == 0.0 or value == 0:
+                    # hugging face exception raised when temp==0
+                    # Failed: Error occurred: HuggingfaceException - Input validation error: `temperature` must be strictly positive
+                    value = 0.01
+                optional_params["temperature"] = value
+            if param == "top_p":
+                optional_params["top_p"] = value
+            if param == "n":
+                optional_params["best_of"] = value
+                optional_params["do_sample"] = (
+                    True  # Need to sample if you want best of for hf inference endpoints
+                )
+            if param == "stream":
+                optional_params["stream"] = value
+            if param == "stop":
+                optional_params["stop"] = value
+            if param == "max_tokens":
+                # HF TGI raises the following exception when max_new_tokens==0
+                # Failed: Error occurred: HuggingfaceException - Input validation error: `max_new_tokens` must be strictly positive
+                if value == 0:
+                    value = 1
+                optional_params["max_new_tokens"] = value
+            if param == "echo":
+                # https://huggingface.co/docs/huggingface_hub/main/en/package_reference/inference_client#huggingface_hub.InferenceClient.text_generation.decoder_input_details
+                # Return the decoder input token logprobs and ids. You must set details=True as well for it to be taken into account. Defaults to False
+                optional_params["decoder_input_details"] = True
+            if param == "response_format":
+                optional_params["response_format"] = value
+        return optional_params
 
 
 class PredibaseChatCompletion(BaseLLM):
@@ -225,15 +268,16 @@ class PredibaseChatCompletion(BaseLLM):
                 status_code=response.status_code,
             )
         else:
-            if (
-                not isinstance(completion_response, dict)
-                or "generated_text" not in completion_response
-            ):
+            if not isinstance(completion_response, dict):
                 raise PredibaseError(
                     status_code=422,
-                    message=f"response is not in expected format - {completion_response}",
+                    message=f"'completion_response' is not a dictionary - {completion_response}",
+                )
+            elif "generated_text" not in completion_response:
+                raise PredibaseError(
+                    status_code=422,
+                    message=f"'generated_text' is not a key in the response dictionary - {completion_response}",
                 )
-
             if len(completion_response["generated_text"]) > 0:
                 model_response["choices"][0]["message"]["content"] = self.output_parser(
                     completion_response["generated_text"]
                 )
@@ -496,7 +540,9 @@ class PredibaseChatCompletion(BaseLLM):
         except httpx.HTTPStatusError as e:
             raise PredibaseError(
                 status_code=e.response.status_code,
-                message="HTTPStatusError - {}".format(e.response.text),
+                message="HTTPStatusError - received status_code={}, error_message={}".format(
+                    e.response.status_code, e.response.text
+                ),
             )
         except Exception as e:
             raise PredibaseError(
diff --git a/litellm/proxy/_super_secret_config.yaml b/litellm/proxy/_super_secret_config.yaml
index c5f1b4768..2060f61ca 100644
--- a/litellm/proxy/_super_secret_config.yaml
+++ b/litellm/proxy/_super_secret_config.yaml
@@ -14,13 +14,10 @@ model_list:
 - model_name: fake-openai-endpoint
   litellm_params:
     model: predibase/llama-3-8b-instruct
-    api_base: "http://0.0.0.0:8000"
+    api_base: "http://0.0.0.0:8081"
     api_key: os.environ/PREDIBASE_API_KEY
     tenant_id: os.environ/PREDIBASE_TENANT_ID
-    max_retries: 0
-    temperature: 0.1
     max_new_tokens: 256
-    return_full_text: false
 
 # - litellm_params:
 #     api_base: https://my-endpoint-europe-berri-992.openai.azure.com/
@@ -73,6 +70,8 @@ model_list:
 
 litellm_settings:
   callbacks: ["dynamic_rate_limiter"]
+  # success_callback: ["langfuse"]
+  # failure_callback: ["langfuse"]
   # default_team_settings:
   #   - team_id: proj1
   #     success_callback: ["langfuse"]
@@ -94,8 +93,8 @@ assistant_settings:
 
 router_settings:
   enable_pre_call_checks: true
-general_settings:
-  alerting: ["slack"]
-  enable_jwt_auth: True
-  litellm_jwtauth:
-    team_id_jwt_field: "client_id"
\ No newline at end of file
+# general_settings:
+#   # alerting: ["slack"]
+#   enable_jwt_auth: True
+#   litellm_jwtauth:
+#     team_id_jwt_field: "client_id"
\ No newline at end of file
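Review note: a small, self-contained illustration of what the new `PredibaseConfig.map_openai_params()` is expected to produce. The input values below are made up for the example, not taken from any test in this PR.

# Sketch of the parameter mapping added above; values are illustrative only.
import litellm

mapped = litellm.PredibaseConfig().map_openai_params(
    non_default_params={
        "temperature": 0,  # bumped to 0.01, since TGI rejects temperature == 0
        "n": 2,  # becomes best_of=2 together with do_sample=True
        "max_tokens": 0,  # becomes max_new_tokens=1 (must be strictly positive)
        "response_format": {"type": "json_object"},  # newly supported, passed through unchanged
    },
    optional_params={},
)
print(mapped)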
diff --git a/litellm/proxy/auth/litellm_license.py b/litellm/proxy/auth/litellm_license.py
index ffd9f5273..0310dcaf5 100644
--- a/litellm/proxy/auth/litellm_license.py
+++ b/litellm/proxy/auth/litellm_license.py
@@ -1,6 +1,11 @@
 # What is this?
 ## If litellm license in env, checks if it's valid
+import base64
+import json
 import os
+from datetime import datetime
+
+from litellm._logging import verbose_proxy_logger
 from litellm.llms.custom_httpx.http_handler import HTTPHandler
@@ -15,6 +20,26 @@ class LicenseCheck:
     def __init__(self) -> None:
         self.license_str = os.getenv("LITELLM_LICENSE", None)
         self.http_handler = HTTPHandler()
+        self.public_key = None
+        self.read_public_key()
+
+    def read_public_key(self):
+        try:
+            from cryptography.hazmat.primitives import hashes, serialization
+            from cryptography.hazmat.primitives.asymmetric import padding, rsa
+
+            # current dir
+            current_dir = os.path.dirname(os.path.realpath(__file__))
+
+            # check if public_key.pem exists
+            _path_to_public_key = os.path.join(current_dir, "public_key.pem")
+            if os.path.exists(_path_to_public_key):
+                with open(_path_to_public_key, "rb") as key_file:
+                    self.public_key = serialization.load_pem_public_key(key_file.read())
+            else:
+                self.public_key = None
+        except Exception as e:
+            verbose_proxy_logger.error(f"Error reading public key: {str(e)}")
 
     def _verify(self, license_str: str) -> bool:
         url = "{}/verify_license/{}".format(self.base_url, license_str)
@@ -35,11 +60,58 @@ class LicenseCheck:
             return False
 
     def is_premium(self) -> bool:
+        """
+        1. verify_license_without_api_request: checks if the license was generated using the private / public key pair
+        2. _verify: checks if the license is valid by calling the litellm API. This is the old way we were generating/validating licenses
+        """
         try:
             if self.license_str is None:
                 return False
+            elif self.verify_license_without_api_request(
+                public_key=self.public_key, license_key=self.license_str
+            ):
+                return True
             elif self._verify(license_str=self.license_str):
                 return True
             return False
         except Exception as e:
             return False
+
+    def verify_license_without_api_request(self, public_key, license_key):
+        try:
+            from cryptography.hazmat.primitives import hashes, serialization
+            from cryptography.hazmat.primitives.asymmetric import padding, rsa
+
+            # Decode the license key
+            decoded = base64.b64decode(license_key)
+            message, signature = decoded.split(b".", 1)
+
+            # Verify the signature
+            public_key.verify(
+                signature,
+                message,
+                padding.PSS(
+                    mgf=padding.MGF1(hashes.SHA256()),
+                    salt_length=padding.PSS.MAX_LENGTH,
+                ),
+                hashes.SHA256(),
+            )
+
+            # Decode and parse the data
+            license_data = json.loads(message.decode())
+
+            # debug information provided in license data
+            verbose_proxy_logger.debug("License data: %s", license_data)
+
+            # Check expiration date
+            expiration_date = datetime.strptime(
+                license_data["expiration_date"], "%Y-%m-%d"
+            )
+            if expiration_date < datetime.now():
+                return False, "License has expired"
+
+            return True
+
+        except Exception as e:
+            verbose_proxy_logger.error(str(e))
+            return False
diff --git a/litellm/proxy/auth/public_key.pem b/litellm/proxy/auth/public_key.pem
new file mode 100644
index 000000000..12a69dde2
--- /dev/null
+++ b/litellm/proxy/auth/public_key.pem
@@ -0,0 +1,9 @@
+-----BEGIN PUBLIC KEY-----
+MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAmfBuNiNzDkNWyce23koQ
+w0vq3bSVHkq7fd9Sw/U1q7FwRwL221daLTyGWssd8xAoQSFXAJKoBwzJQ9wd+o44
+lfL54E3a61nfjZuF+D9ntpXZFfEAxLVtIahDeQjUz4b/EpgciWIJyUfjCJrQo6LY
+eyAZPTGSO8V3zHyaU+CFywq5XCuCnfZqCZeCw051St59A2v8W32mXSCJ+A+x0hYP
+yXJyRRFcefSFG5IBuRHr4Y24Vx7NUIAoco5cnxJho9g2z3J/Hb0GKW+oBNvRVumk
+nuA2Ljmjh4yI0OoTIW8ZWxemvCCJHSjdfKlMyb+QI4fmeiIUZzP5Au+F561Styqq
+YQIDAQAB
-----END PUBLIC KEY-----
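Review note: the license format implied by `verify_license_without_api_request()` is `base64(json_payload + b"." + RSA-PSS/SHA-256 signature)`. Below is a sketch that round-trips that format with a throwaway key pair for local testing; only the holder of the real private key can produce licenses that validate against the bundled `public_key.pem`.

# Sketch only: build and verify a self-signed test license in the expected format.
import base64
import json

from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import padding, rsa

private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)  # stand-in key
payload = json.dumps({"expiration_date": "2030-01-01"}).encode()  # %Y-%m-%d, as parsed above
signature = private_key.sign(
    payload,
    padding.PSS(mgf=padding.MGF1(hashes.SHA256()), salt_length=padding.PSS.MAX_LENGTH),
    hashes.SHA256(),
)
license_key = base64.b64encode(payload + b"." + signature).decode()

# Verification mirrors LicenseCheck.verify_license_without_api_request().
decoded = base64.b64decode(license_key)
message, sig = decoded.split(b".", 1)
private_key.public_key().verify(
    sig,
    message,
    padding.PSS(mgf=padding.MGF1(hashes.SHA256()), salt_length=padding.PSS.MAX_LENGTH),
    hashes.SHA256(),
)
print("signature ok, expires", json.loads(message.decode())["expiration_date"])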
diff --git a/litellm/utils.py b/litellm/utils.py
index 9f6ebaff0..4465c5b0a 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -2609,7 +2609,15 @@ def get_optional_params(
             optional_params["top_p"] = top_p
         if stop is not None:
             optional_params["stop_sequences"] = stop
-    elif custom_llm_provider == "huggingface" or custom_llm_provider == "predibase":
+    elif custom_llm_provider == "predibase":
+        supported_params = get_supported_openai_params(
+            model=model, custom_llm_provider=custom_llm_provider
+        )
+        _check_valid_arg(supported_params=supported_params)
+        optional_params = litellm.PredibaseConfig().map_openai_params(
+            non_default_params=non_default_params, optional_params=optional_params
+        )
+    elif custom_llm_provider == "huggingface":
         ## check if unsupported param passed in
         supported_params = get_supported_openai_params(
             model=model, custom_llm_provider=custom_llm_provider
         )
@@ -6157,13 +6165,6 @@ def exception_type(
                         response=original_exception.response,
                         litellm_debug_info=extra_information,
                     )
-                if "Request failed during generation" in error_str:
-                    # this is an internal server error from predibase
-                    raise litellm.InternalServerError(
-                        message=f"PredibaseException - {error_str}",
-                        llm_provider="predibase",
-                        model=model,
-                    )
                 elif hasattr(original_exception, "status_code"):
                     if original_exception.status_code == 500:
                         exception_mapping_worked = True
@@ -6201,7 +6202,10 @@ def exception_type(
                             llm_provider=custom_llm_provider,
                             litellm_debug_info=extra_information,
                         )
-                    elif original_exception.status_code == 422:
+                    elif (
+                        original_exception.status_code == 422
+                        or original_exception.status_code == 424
+                    ):
                         exception_mapping_worked = True
                         raise BadRequestError(
                             message=f"PredibaseException - {original_exception.message}",
diff --git a/requirements.txt b/requirements.txt
index e40c44e4d..00d3802da 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -32,6 +32,7 @@ opentelemetry-api==1.25.0
 opentelemetry-sdk==1.25.0
 opentelemetry-exporter-otlp==1.25.0
 detect-secrets==1.5.0 # Enterprise - secret detection / masking in LLM requests
+cryptography==42.0.7
 
 ### LITELLM PACKAGE DEPENDENCIES
 python-dotenv==1.0.0 # for env
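Review note: end to end, this change means `response_format` is now accepted for predibase models and 422/424 responses from Predibase surface as `BadRequestError`. A hedged usage sketch; the model name and api_base mirror the test config above, and the two env vars are assumed to be set.

# Sketch: response_format flows through get_optional_params -> PredibaseConfig.map_openai_params.
import os

import litellm

response = litellm.completion(
    model="predibase/llama-3-8b-instruct",
    messages=[{"role": "user", "content": "Reply with a JSON object containing a 'city' key."}],
    api_base="http://0.0.0.0:8081",  # mirrors the updated _super_secret_config.yaml entry
    api_key=os.environ["PREDIBASE_API_KEY"],
    tenant_id=os.environ["PREDIBASE_TENANT_ID"],
    response_format={"type": "json_object"},
    max_tokens=256,
)
print(response.choices[0].message.content)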