diff --git a/litellm/__init__.py b/litellm/__init__.py
index adb49c739e..81e386f0f9 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -1084,8 +1084,8 @@ from .llms.deprecated_providers.palm import (
     PalmConfig,
 )  # here to prevent breaking changes
 from .llms.nlp_cloud.chat.handler import NLPCloudConfig
+from .llms.petals.completion.transformation import PetalsConfig
 from .llms.deprecated_providers.aleph_alpha import AlephAlphaConfig
-from .llms.petals import PetalsConfig
 from .llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import (
     VertexGeminiConfig,
     GoogleAIStudioGeminiConfig,
diff --git a/litellm/litellm_core_utils/get_supported_openai_params.py b/litellm/litellm_core_utils/get_supported_openai_params.py
index 40753839d9..f34bed8b29 100644
--- a/litellm/litellm_core_utils/get_supported_openai_params.py
+++ b/litellm/litellm_core_utils/get_supported_openai_params.py
@@ -172,7 +172,7 @@ def get_supported_openai_params(  # noqa: PLR0915
     elif custom_llm_provider == "nlp_cloud":
         return litellm.NLPCloudConfig().get_supported_openai_params(model=model)
     elif custom_llm_provider == "petals":
-        return ["max_tokens", "temperature", "top_p", "stream"]
+        return litellm.PetalsConfig().get_supported_openai_params(model=model)
     elif custom_llm_provider == "deepinfra":
         return litellm.DeepInfraConfig().get_supported_openai_params(model=model)
     elif custom_llm_provider == "perplexity":
diff --git a/litellm/llms/petals/common_utils.py b/litellm/llms/petals/common_utils.py
new file mode 100644
index 0000000000..9df4bad8eb
--- /dev/null
+++ b/litellm/llms/petals/common_utils.py
@@ -0,0 +1,10 @@
+from typing import Union
+
+from httpx import Headers
+
+from litellm.llms.base_llm.transformation import BaseLLMException
+
+
+class PetalsError(BaseLLMException):
+    def __init__(self, status_code: int, message: str, headers: Union[dict, Headers]):
+        super().__init__(status_code=status_code, message=message, headers=headers)
diff --git a/litellm/llms/petals.py b/litellm/llms/petals/completion/handler.py
similarity index 57%
rename from litellm/llms/petals.py
rename to litellm/llms/petals/completion/handler.py
index 9194781f4f..108a8a334a 100644
--- a/litellm/llms/petals.py
+++ b/litellm/llms/petals/completion/handler.py
@@ -3,91 +3,21 @@ import os
 import time
 import types
 from enum import Enum
-from typing import Callable, Optional
-
-import requests  # type: ignore
+from typing import Callable, Optional, Union
 
 import litellm
+from litellm.litellm_core_utils.prompt_templates.factory import (
+    custom_prompt,
+    prompt_factory,
+)
+from litellm.llms.custom_httpx.http_handler import (
+    AsyncHTTPHandler,
+    HTTPHandler,
+    _get_httpx_client,
+)
 from litellm.utils import ModelResponse, Usage
-from litellm.litellm_core_utils.prompt_templates.factory import custom_prompt, prompt_factory
-
-
-class PetalsError(Exception):
-    def __init__(self, status_code, message):
-        self.status_code = status_code
-        self.message = message
-        super().__init__(
-            self.message
-        )  # Call the base class constructor with the parameters it needs
-
-
-class PetalsConfig:
-    """
-    Reference: https://github.com/petals-infra/chat.petals.dev#post-apiv1generate
-    The `PetalsConfig` class encapsulates the configuration for the Petals API. The properties of this class are described below:
-
-    - `max_length` (integer): This represents the maximum length of the generated text (including the prefix) in tokens.
-
-    - `max_new_tokens` (integer): This represents the maximum number of newly generated tokens (excluding the prefix).
-
-    The generation parameters are compatible with `.generate()` from Hugging Face's Transformers library:
-
-    - `do_sample` (boolean, optional): If set to 0 (default), the API runs greedy generation. If set to 1, the API performs sampling using the parameters below:
-
-    - `temperature` (float, optional): This value sets the temperature for sampling.
-
-    - `top_k` (integer, optional): This value sets the limit for top-k sampling.
-
-    - `top_p` (float, optional): This value sets the limit for top-p (nucleus) sampling.
-
-    - `repetition_penalty` (float, optional): This helps apply the repetition penalty during text generation, as discussed in this paper.
-    """
-
-    max_length: Optional[int] = None
-    max_new_tokens: Optional[int] = (
-        litellm.max_tokens
-    )  # petals requires max tokens to be set
-    do_sample: Optional[bool] = None
-    temperature: Optional[float] = None
-    top_k: Optional[int] = None
-    top_p: Optional[float] = None
-    repetition_penalty: Optional[float] = None
-
-    def __init__(
-        self,
-        max_length: Optional[int] = None,
-        max_new_tokens: Optional[
-            int
-        ] = litellm.max_tokens,  # petals requires max tokens to be set
-        do_sample: Optional[bool] = None,
-        temperature: Optional[float] = None,
-        top_k: Optional[int] = None,
-        top_p: Optional[float] = None,
-        repetition_penalty: Optional[float] = None,
-    ) -> None:
-        locals_ = locals()
-        for key, value in locals_.items():
-            if key != "self" and value is not None:
-                setattr(self.__class__, key, value)
-
-    @classmethod
-    def get_config(cls):
-        return {
-            k: v
-            for k, v in cls.__dict__.items()
-            if not k.startswith("__")
-            and not isinstance(
-                v,
-                (
-                    types.FunctionType,
-                    types.BuiltinFunctionType,
-                    classmethod,
-                    staticmethod,
-                ),
-            )
-            and v is not None
-        }
+from ..common_utils import PetalsError
 
 
 def completion(
@@ -102,6 +32,7 @@ def completion(
     stream=False,
     litellm_params=None,
     logger_fn=None,
+    client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
 ):
     ## Load Config
     config = litellm.PetalsConfig.get_config()
@@ -137,7 +68,9 @@ def completion(
         data = {"model": model, "inputs": prompt, **optional_params}
 
         ## COMPLETION CALL
-        response = requests.post(api_base, data=data)
+        if client is None or not isinstance(client, HTTPHandler):
+            client = _get_httpx_client()
+        response = client.post(api_base, data=data)
 
         ## LOGGING
         logging_obj.post_call(
@@ -151,7 +84,11 @@ def completion(
         try:
             output_text = response.json()["outputs"]
         except Exception as e:
-            PetalsError(status_code=response.status_code, message=str(e))
+            raise PetalsError(
+                status_code=response.status_code,
+                message=str(e),
+                headers=response.headers,
+            )
 
     else:
         try:
diff --git a/litellm/llms/petals/completion/transformation.py b/litellm/llms/petals/completion/transformation.py
new file mode 100644
index 0000000000..17386a7df5
--- /dev/null
+++ b/litellm/llms/petals/completion/transformation.py
@@ -0,0 +1,136 @@
+import types
+from typing import Any, List, Optional, Union
+
+from httpx import Headers, Response
+
+import litellm
+from litellm.llms.base_llm.transformation import (
+    BaseConfig,
+    BaseLLMException,
+    LiteLLMLoggingObj,
+)
+from litellm.types.llms.openai import AllMessageValues
+
+from ..common_utils import PetalsError
+
+
+class PetalsConfig(BaseConfig):
+    """
+    Reference: https://github.com/petals-infra/chat.petals.dev#post-apiv1generate
+    The `PetalsConfig` class encapsulates the configuration for the Petals API. The properties of this class are described below:
+
+    - `max_length` (integer): This represents the maximum length of the generated text (including the prefix) in tokens.
+
+    - `max_new_tokens` (integer): This represents the maximum number of newly generated tokens (excluding the prefix).
+
+    The generation parameters are compatible with `.generate()` from Hugging Face's Transformers library:
+
+    - `do_sample` (boolean, optional): If set to 0 (default), the API runs greedy generation. If set to 1, the API performs sampling using the parameters below:
+
+    - `temperature` (float, optional): This value sets the temperature for sampling.
+
+    - `top_k` (integer, optional): This value sets the limit for top-k sampling.
+
+    - `top_p` (float, optional): This value sets the limit for top-p (nucleus) sampling.
+
+    - `repetition_penalty` (float, optional): This helps apply the repetition penalty during text generation, as discussed in this paper.
+    """
+
+    max_length: Optional[int] = None
+    max_new_tokens: Optional[int] = (
+        litellm.max_tokens
+    )  # petals requires max tokens to be set
+    do_sample: Optional[bool] = None
+    temperature: Optional[float] = None
+    top_k: Optional[int] = None
+    top_p: Optional[float] = None
+    repetition_penalty: Optional[float] = None
+
+    def __init__(
+        self,
+        max_length: Optional[int] = None,
+        max_new_tokens: Optional[
+            int
+        ] = litellm.max_tokens,  # petals requires max tokens to be set
+        do_sample: Optional[bool] = None,
+        temperature: Optional[float] = None,
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None,
+        repetition_penalty: Optional[float] = None,
+    ) -> None:
+        locals_ = locals()
+        for key, value in locals_.items():
+            if key != "self" and value is not None:
+                setattr(self.__class__, key, value)
+
+    @classmethod
+    def get_config(cls):
+        return super().get_config()
+
+    def get_error_class(
+        self, error_message: str, status_code: int, headers: Union[dict, Headers]
+    ) -> BaseLLMException:
+        return PetalsError(
+            status_code=status_code, message=error_message, headers=headers
+        )
+
+    def get_supported_openai_params(self, model: str) -> List:
+        return ["max_tokens", "temperature", "top_p", "stream"]
+
+    def map_openai_params(
+        self,
+        non_default_params: dict,
+        optional_params: dict,
+        model: str,
+        drop_params: bool,
+    ) -> dict:
+        for param, value in non_default_params.items():
+            if param == "max_tokens":
+                optional_params["max_new_tokens"] = value
+            if param == "temperature":
+                optional_params["temperature"] = value
+            if param == "top_p":
+                optional_params["top_p"] = value
+            if param == "stream":
+                optional_params["stream"] = value
+        return optional_params
+
+    def transform_request(
+        self,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        headers: dict,
+    ) -> dict:
+        raise NotImplementedError(
+            "Petals transformation currently done in handler.py. [TODO] Move to the transformation.py"
+        )
+
+    def transform_response(
+        self,
+        model: str,
+        raw_response: Response,
+        model_response: litellm.ModelResponse,
+        logging_obj: LiteLLMLoggingObj,
+        request_data: dict,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        encoding: Any,
+        api_key: Optional[str] = None,
+        json_mode: Optional[bool] = None,
+    ) -> litellm.ModelResponse:
+        raise NotImplementedError(
+            "Petals transformation currently done in handler.py. [TODO] Move to the transformation.py"
+        )
+
+    def validate_environment(
+        self,
+        headers: dict,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        api_key: Optional[str] = None,
+    ) -> dict:
+        return {}
diff --git a/litellm/main.py b/litellm/main.py
index f9aec9e0ab..6becf201a3 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -94,7 +94,7 @@ from .litellm_core_utils.prompt_templates.factory import (
     stringify_json_tool_call_content,
 )
 from .litellm_core_utils.streaming_chunk_builder_utils import ChunkProcessor
-from .llms import baseten, maritalk, ollama_chat, petals
+from .llms import baseten, maritalk, ollama_chat
 from .llms.anthropic.chat import AnthropicChatCompletion
 from .llms.azure.audio_transcriptions import AzureAudioTranscription
 from .llms.azure.azure import AzureChatCompletion, _check_dynamic_azure_params
@@ -120,6 +120,7 @@ from .llms.openai.openai import OpenAIChatCompletion
 from .llms.openai.transcriptions.handler import OpenAIAudioTranscription
 from .llms.openai_like.chat.handler import OpenAILikeChatHandler
 from .llms.openai_like.embedding.handler import OpenAILikeEmbeddingHandler
+from .llms.petals.completion import handler as petals_handler
 from .llms.predibase.chat.handler import PredibaseChatCompletion
 from .llms.replicate.chat.handler import completion as replicate_chat_completion
 from .llms.sagemaker.chat.handler import SagemakerChatHandler
@@ -2791,7 +2792,7 @@ def completion(  # type: ignore # noqa: PLR0915
             custom_llm_provider = "petals"
             stream = optional_params.pop("stream", False)
 
-            model_response = petals.completion(
+            model_response = petals_handler.completion(
                 model=model,
                 messages=messages,
                 api_base=api_base,
@@ -2802,6 +2803,7 @@ def completion(  # type: ignore # noqa: PLR0915
                 logger_fn=logger_fn,
                 encoding=encoding,
                 logging_obj=logging,
+                client=client,
             )
             if stream is True:  ## [BETA]
                 # Fake streaming for petals
diff --git a/litellm/utils.py b/litellm/utils.py
index 1772e1337b..3bbab724a8 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -3335,15 +3335,16 @@ def get_optional_params(  # noqa: PLR0915
             model=model, custom_llm_provider=custom_llm_provider
         )
         _check_valid_arg(supported_params=supported_params)
-        # max_new_tokens=1,temperature=0.9, top_p=0.6
-        if max_tokens is not None:
-            optional_params["max_new_tokens"] = max_tokens
-        if temperature is not None:
-            optional_params["temperature"] = temperature
-        if top_p is not None:
-            optional_params["top_p"] = top_p
-        if stream:
-            optional_params["stream"] = stream
+        optional_params = litellm.PetalsConfig().map_openai_params(
+            non_default_params=non_default_params,
+            optional_params=optional_params,
+            model=model,
+            drop_params=(
+                drop_params
+                if drop_params is not None and isinstance(drop_params, bool)
+                else False
+            ),
+        )
     elif custom_llm_provider == "deepinfra":
         supported_params = get_supported_openai_params(
             model=model, custom_llm_provider=custom_llm_provider
@@ -6375,6 +6376,8 @@ class ProviderConfigManager:
             return litellm.PredibaseConfig()
         elif litellm.LlmProviders.TRITON == provider:
             return litellm.TritonConfig()
+        elif litellm.LlmProviders.PETALS == provider:
+            return litellm.PetalsConfig()
 
         return litellm.OpenAIGPTConfig()
diff --git a/tests/local_testing/test_completion.py b/tests/local_testing/test_completion.py
index e494c2aba5..1c8b04575c 100644
--- a/tests/local_testing/test_completion.py
+++ b/tests/local_testing/test_completion.py
@@ -3685,18 +3685,27 @@ def test_mistral_anyscale_stream():
 #         error_str = traceback.format_exc()
 #         pytest.fail(f"Error occurred: {error_str}")
{error_str}") -# test_completion_with_fallbacks_multiple_keys() -# def test_petals(): -# try: -# response = completion(model="petals-team/StableBeluga2", messages=messages) -# # Add any assertions here to check the response -# print(response) -# response = completion(model="petals-team/StableBeluga2", messages=messages) -# # Add any assertions here to check the response -# print(response) -# except Exception as e: -# pytest.fail(f"Error occurred: {e}") +# test_completion_with_fallbacks_multiple_keys() +def test_petals(): + try: + from litellm.llms.custom_httpx.http_handler import HTTPHandler + + client = HTTPHandler() + with patch.object(client, "post") as mock_post: + try: + completion( + model="petals-team/StableBeluga2", + messages=messages, + client=client, + api_base="https://api.petals.dev", + ) + except Exception as e: + print(f"Error occurred: {e}") + mock_post.assert_called_once() + except Exception as e: + pytest.fail(f"Error occurred: {e}") + # def test_baseten(): # try: diff --git a/tests/local_testing/test_config.py b/tests/local_testing/test_config.py index 2d49338539..a63816e8e2 100644 --- a/tests/local_testing/test_config.py +++ b/tests/local_testing/test_config.py @@ -323,8 +323,13 @@ def _check_provider_config(config: BaseConfig, provider: LlmProviders): # or provider == LlmProviders.BEDROCK # or provider == LlmProviders.BASETEN # or provider == LlmProviders.PETALS +# or provider == LlmProviders.SAGEMAKER +# or provider == LlmProviders.SAGEMAKER_CHAT +# or provider == LlmProviders.VLLM +# or provider == LlmProviders.OLLAMA # ): # continue + # config = ProviderConfigManager.get_provider_chat_config( # model="gpt-3.5-turbo", provider=LlmProviders(provider) # )