Krrish Dholakia 2024-12-11 01:10:12 -08:00
parent efbec4230b
commit 02dd0c6e7e
9 changed files with 209 additions and 107 deletions

View file

@@ -1084,8 +1084,8 @@ from .llms.deprecated_providers.palm import (
     PalmConfig,
 )  # here to prevent breaking changes
 from .llms.nlp_cloud.chat.handler import NLPCloudConfig
+from .llms.petals.completion.transformation import PetalsConfig
 from .llms.deprecated_providers.aleph_alpha import AlephAlphaConfig
-from .llms.petals import PetalsConfig
 from .llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import (
     VertexGeminiConfig,
     GoogleAIStudioGeminiConfig,

View file

@@ -172,7 +172,7 @@ def get_supported_openai_params( # noqa: PLR0915
     elif custom_llm_provider == "nlp_cloud":
         return litellm.NLPCloudConfig().get_supported_openai_params(model=model)
     elif custom_llm_provider == "petals":
-        return ["max_tokens", "temperature", "top_p", "stream"]
+        return litellm.PetalsConfig().get_supported_openai_params(model=model)
     elif custom_llm_provider == "deepinfra":
         return litellm.DeepInfraConfig().get_supported_openai_params(model=model)
     elif custom_llm_provider == "perplexity":
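The Petals parameter list is no longer hard-coded here; it now comes from the config class added later in this diff. A minimal sketch of the equivalent lookup (the model name is an example taken from the tests at the end of this diff):

import litellm

# PetalsConfig.get_supported_openai_params() returns the same list that used to
# be hard-coded above: ["max_tokens", "temperature", "top_p", "stream"].
supported = litellm.PetalsConfig().get_supported_openai_params(
    model="petals-team/StableBeluga2"
)
print(supported)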

View file

@@ -0,0 +1,10 @@
+from typing import Union
+
+from httpx import Headers
+
+from litellm.llms.base_llm.transformation import BaseLLMException
+
+
+class PetalsError(BaseLLMException):
+    def __init__(self, status_code: int, message: str, headers: Union[dict, Headers]):
+        super().__init__(status_code=status_code, message=message, headers=headers)
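PetalsError now subclasses the shared BaseLLMException and carries the response headers. A hedged sketch of raising it from an httpx response — the helper below is illustrative, not part of this commit, and assumes the module path implied by the relative import in the handler diff below:

import httpx

from litellm.llms.petals.common_utils import PetalsError


def raise_petals_error(response: httpx.Response) -> None:
    # Illustrative only: surface a failed Petals HTTP call as a PetalsError so
    # the shared exception type carries status code, message, and headers.
    raise PetalsError(
        status_code=response.status_code,
        message=response.text,
        headers=response.headers,
    )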

View file

@@ -3,91 +3,21 @@ import os
 import time
 import types
 from enum import Enum
-from typing import Callable, Optional
+from typing import Callable, Optional, Union
-import requests  # type: ignore
 import litellm
+from litellm.litellm_core_utils.prompt_templates.factory import (
+    custom_prompt,
+    prompt_factory,
+)
+from litellm.llms.custom_httpx.http_handler import (
+    AsyncHTTPHandler,
+    HTTPHandler,
+    _get_httpx_client,
+)
 from litellm.utils import ModelResponse, Usage
-from litellm.litellm_core_utils.prompt_templates.factory import custom_prompt, prompt_factory
+from ..common_utils import PetalsError
-
-
-class PetalsError(Exception):
-    def __init__(self, status_code, message):
-        self.status_code = status_code
-        self.message = message
-        super().__init__(
-            self.message
-        )  # Call the base class constructor with the parameters it needs
-
-
-class PetalsConfig:
-    """
-    Reference: https://github.com/petals-infra/chat.petals.dev#post-apiv1generate
-    The `PetalsConfig` class encapsulates the configuration for the Petals API. The properties of this class are described below:
-    - `max_length` (integer): This represents the maximum length of the generated text (including the prefix) in tokens.
-    - `max_new_tokens` (integer): This represents the maximum number of newly generated tokens (excluding the prefix).
-    The generation parameters are compatible with `.generate()` from Hugging Face's Transformers library:
-    - `do_sample` (boolean, optional): If set to 0 (default), the API runs greedy generation. If set to 1, the API performs sampling using the parameters below:
-    - `temperature` (float, optional): This value sets the temperature for sampling.
-    - `top_k` (integer, optional): This value sets the limit for top-k sampling.
-    - `top_p` (float, optional): This value sets the limit for top-p (nucleus) sampling.
-    - `repetition_penalty` (float, optional): This helps apply the repetition penalty during text generation, as discussed in this paper.
-    """
-
-    max_length: Optional[int] = None
-    max_new_tokens: Optional[int] = (
-        litellm.max_tokens
-    )  # petals requires max tokens to be set
-    do_sample: Optional[bool] = None
-    temperature: Optional[float] = None
-    top_k: Optional[int] = None
-    top_p: Optional[float] = None
-    repetition_penalty: Optional[float] = None
-
-    def __init__(
-        self,
-        max_length: Optional[int] = None,
-        max_new_tokens: Optional[
-            int
-        ] = litellm.max_tokens,  # petals requires max tokens to be set
-        do_sample: Optional[bool] = None,
-        temperature: Optional[float] = None,
-        top_k: Optional[int] = None,
-        top_p: Optional[float] = None,
-        repetition_penalty: Optional[float] = None,
-    ) -> None:
-        locals_ = locals()
-        for key, value in locals_.items():
-            if key != "self" and value is not None:
-                setattr(self.__class__, key, value)
-
-    @classmethod
-    def get_config(cls):
-        return {
-            k: v
-            for k, v in cls.__dict__.items()
-            if not k.startswith("__")
-            and not isinstance(
-                v,
-                (
-                    types.FunctionType,
-                    types.BuiltinFunctionType,
-                    classmethod,
-                    staticmethod,
-                ),
-            )
-            and v is not None
-        }
-
-
 def completion(
@@ -102,6 +32,7 @@ def completion(
     stream=False,
     litellm_params=None,
     logger_fn=None,
+    client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
 ):
     ## Load Config
     config = litellm.PetalsConfig.get_config()
@@ -137,7 +68,9 @@ def completion(
         data = {"model": model, "inputs": prompt, **optional_params}
 
         ## COMPLETION CALL
-        response = requests.post(api_base, data=data)
+        if client is None or not isinstance(client, HTTPHandler):
+            client = _get_httpx_client()
+        response = client.post(api_base, data=data)
 
         ## LOGGING
         logging_obj.post_call(
@@ -151,7 +84,11 @@ def completion(
         try:
             output_text = response.json()["outputs"]
         except Exception as e:
-            PetalsError(status_code=response.status_code, message=str(e))
+            PetalsError(
+                status_code=response.status_code,
+                message=str(e),
+                headers=response.headers,
+            )
     else:
         try:
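The handler now routes the HTTP call through litellm's httpx-based HTTPHandler instead of requests.post and accepts an optional client, which is what the mocked test near the end of this diff relies on. A minimal usage sketch, assuming the model name and public endpoint used in that test:

import litellm
from litellm.llms.custom_httpx.http_handler import HTTPHandler

# A reusable httpx-backed client; when no client (or a non-HTTPHandler) is
# passed, the handler falls back to _get_httpx_client().
client = HTTPHandler()

response = litellm.completion(
    model="petals-team/StableBeluga2",
    messages=[{"role": "user", "content": "Hello, how are you?"}],
    api_base="https://api.petals.dev",
    client=client,
)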

View file

@@ -0,0 +1,136 @@
+import types
+from typing import Any, List, Optional, Union
+
+from httpx import Headers, Response
+
+import litellm
+from litellm.llms.base_llm.transformation import (
+    BaseConfig,
+    BaseLLMException,
+    LiteLLMLoggingObj,
+)
+from litellm.types.llms.openai import AllMessageValues
+
+from ..common_utils import PetalsError
+
+
+class PetalsConfig(BaseConfig):
+    """
+    Reference: https://github.com/petals-infra/chat.petals.dev#post-apiv1generate
+    The `PetalsConfig` class encapsulates the configuration for the Petals API. The properties of this class are described below:
+    - `max_length` (integer): This represents the maximum length of the generated text (including the prefix) in tokens.
+    - `max_new_tokens` (integer): This represents the maximum number of newly generated tokens (excluding the prefix).
+    The generation parameters are compatible with `.generate()` from Hugging Face's Transformers library:
+    - `do_sample` (boolean, optional): If set to 0 (default), the API runs greedy generation. If set to 1, the API performs sampling using the parameters below:
+    - `temperature` (float, optional): This value sets the temperature for sampling.
+    - `top_k` (integer, optional): This value sets the limit for top-k sampling.
+    - `top_p` (float, optional): This value sets the limit for top-p (nucleus) sampling.
+    - `repetition_penalty` (float, optional): This helps apply the repetition penalty during text generation, as discussed in this paper.
+    """
+
+    max_length: Optional[int] = None
+    max_new_tokens: Optional[int] = (
+        litellm.max_tokens
+    )  # petals requires max tokens to be set
+    do_sample: Optional[bool] = None
+    temperature: Optional[float] = None
+    top_k: Optional[int] = None
+    top_p: Optional[float] = None
+    repetition_penalty: Optional[float] = None
+
+    def __init__(
+        self,
+        max_length: Optional[int] = None,
+        max_new_tokens: Optional[
+            int
+        ] = litellm.max_tokens,  # petals requires max tokens to be set
+        do_sample: Optional[bool] = None,
+        temperature: Optional[float] = None,
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None,
+        repetition_penalty: Optional[float] = None,
+    ) -> None:
+        locals_ = locals()
+        for key, value in locals_.items():
+            if key != "self" and value is not None:
+                setattr(self.__class__, key, value)
+
+    @classmethod
+    def get_config(cls):
+        return super().get_config()
+
+    def get_error_class(
+        self, error_message: str, status_code: int, headers: Union[dict, Headers]
+    ) -> BaseLLMException:
+        return PetalsError(
+            status_code=status_code, message=error_message, headers=headers
+        )
+
+    def get_supported_openai_params(self, model: str) -> List:
+        return ["max_tokens", "temperature", "top_p", "stream"]
+
+    def map_openai_params(
+        self,
+        non_default_params: dict,
+        optional_params: dict,
+        model: str,
+        drop_params: bool,
+    ) -> dict:
+        for param, value in non_default_params.items():
+            if param == "max_tokens":
+                optional_params["max_new_tokens"] = value
+            if param == "temperature":
+                optional_params["temperature"] = value
+            if param == "top_p":
+                optional_params["top_p"] = value
+            if param == "stream":
+                optional_params["stream"] = value
+        return optional_params
+
+    def transform_request(
+        self,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        headers: dict,
+    ) -> dict:
+        raise NotImplementedError(
+            "Petals transformation currently done in handler.py. [TODO] Move to the transformation.py"
+        )
+
+    def transform_response(
+        self,
+        model: str,
+        raw_response: Response,
+        model_response: litellm.ModelResponse,
+        logging_obj: LiteLLMLoggingObj,
+        request_data: dict,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        encoding: Any,
+        api_key: Optional[str] = None,
+        json_mode: Optional[bool] = None,
+    ) -> litellm.ModelResponse:
+        raise NotImplementedError(
+            "Petals transformation currently done in handler.py. [TODO] Move to the transformation.py"
+        )
+
+    def validate_environment(
+        self,
+        headers: dict,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        api_key: Optional[str] = None,
+    ) -> dict:
+        return {}
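The docstring above lists the native Petals generation parameters; map_openai_params is the piece that translates OpenAI-style kwargs into them, renaming max_tokens to max_new_tokens and passing the rest through. A small sketch of that mapping, assuming the class is exported as litellm.PetalsConfig (per the import change at the top of this diff):

import litellm

config = litellm.PetalsConfig()
optional_params = config.map_openai_params(
    non_default_params={"max_tokens": 256, "temperature": 0.7, "top_p": 0.9},
    optional_params={},
    model="petals-team/StableBeluga2",  # example model name from the tests below
    drop_params=False,
)
# Expected result: {"max_new_tokens": 256, "temperature": 0.7, "top_p": 0.9}
print(optional_params)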

View file

@@ -94,7 +94,7 @@ from .litellm_core_utils.prompt_templates.factory import (
     stringify_json_tool_call_content,
 )
 from .litellm_core_utils.streaming_chunk_builder_utils import ChunkProcessor
-from .llms import baseten, maritalk, ollama_chat, petals
+from .llms import baseten, maritalk, ollama_chat
 from .llms.anthropic.chat import AnthropicChatCompletion
 from .llms.azure.audio_transcriptions import AzureAudioTranscription
 from .llms.azure.azure import AzureChatCompletion, _check_dynamic_azure_params
@@ -120,6 +120,7 @@ from .llms.openai.openai import OpenAIChatCompletion
 from .llms.openai.transcriptions.handler import OpenAIAudioTranscription
 from .llms.openai_like.chat.handler import OpenAILikeChatHandler
 from .llms.openai_like.embedding.handler import OpenAILikeEmbeddingHandler
+from .llms.petals.completion import handler as petals_handler
 from .llms.predibase.chat.handler import PredibaseChatCompletion
 from .llms.replicate.chat.handler import completion as replicate_chat_completion
 from .llms.sagemaker.chat.handler import SagemakerChatHandler
@@ -2791,7 +2792,7 @@ def completion( # type: ignore # noqa: PLR0915
         custom_llm_provider = "petals"
         stream = optional_params.pop("stream", False)
-        model_response = petals.completion(
+        model_response = petals_handler.completion(
             model=model,
             messages=messages,
             api_base=api_base,
@@ -2802,6 +2803,7 @@ def completion( # type: ignore # noqa: PLR0915
             logger_fn=logger_fn,
             encoding=encoding,
             logging_obj=logging,
+            client=client,
         )
         if stream is True:  ## [BETA]
             # Fake streaming for petals

View file

@@ -3335,15 +3335,16 @@ def get_optional_params( # noqa: PLR0915
             model=model, custom_llm_provider=custom_llm_provider
         )
         _check_valid_arg(supported_params=supported_params)
-        # max_new_tokens=1,temperature=0.9, top_p=0.6
-        if max_tokens is not None:
-            optional_params["max_new_tokens"] = max_tokens
-        if temperature is not None:
-            optional_params["temperature"] = temperature
-        if top_p is not None:
-            optional_params["top_p"] = top_p
-        if stream:
-            optional_params["stream"] = stream
+        optional_params = litellm.PetalsConfig().map_openai_params(
+            non_default_params=non_default_params,
+            optional_params=optional_params,
+            model=model,
+            drop_params=(
+                drop_params
+                if drop_params is not None and isinstance(drop_params, bool)
+                else False
+            ),
+        )
     elif custom_llm_provider == "deepinfra":
         supported_params = get_supported_openai_params(
             model=model, custom_llm_provider=custom_llm_provider
@@ -6375,6 +6376,8 @@ class ProviderConfigManager:
             return litellm.PredibaseConfig()
         elif litellm.LlmProviders.TRITON == provider:
             return litellm.TritonConfig()
+        elif litellm.LlmProviders.PETALS == provider:
+            return litellm.PetalsConfig()
         return litellm.OpenAIGPTConfig()
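With PETALS registered in ProviderConfigManager, the generic provider-to-config lookup (the same call exercised by the commented-out test at the end of this diff) resolves to the new class. A hedged sketch, assuming ProviderConfigManager is importable from litellm.utils:

import litellm
from litellm.utils import ProviderConfigManager  # assumed import path for the class in this file's diff

config = ProviderConfigManager.get_provider_chat_config(
    model="petals-team/StableBeluga2",  # example model name from the tests below
    provider=litellm.LlmProviders.PETALS,
)
assert isinstance(config, litellm.PetalsConfig)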

View file

@@ -3685,18 +3685,27 @@ def test_mistral_anyscale_stream():
 #         error_str = traceback.format_exc()
 #         pytest.fail(f"Error occurred: {error_str}")
-# test_completion_with_fallbacks_multiple_keys()
-# def test_petals():
-#     try:
-#         response = completion(model="petals-team/StableBeluga2", messages=messages)
-#         # Add any assertions here to check the response
-#         print(response)
-#         response = completion(model="petals-team/StableBeluga2", messages=messages)
-#         # Add any assertions here to check the response
-#         print(response)
-#     except Exception as e:
-#         pytest.fail(f"Error occurred: {e}")
+# test_completion_with_fallbacks_multiple_keys()
+def test_petals():
+    try:
+        from litellm.llms.custom_httpx.http_handler import HTTPHandler
+
+        client = HTTPHandler()
+        with patch.object(client, "post") as mock_post:
+            try:
+                completion(
+                    model="petals-team/StableBeluga2",
+                    messages=messages,
+                    client=client,
+                    api_base="https://api.petals.dev",
+                )
+            except Exception as e:
+                print(f"Error occurred: {e}")
+            mock_post.assert_called_once()
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
 # def test_baseten():
 #     try:

View file

@@ -323,8 +323,13 @@ def _check_provider_config(config: BaseConfig, provider: LlmProviders):
 #             or provider == LlmProviders.BEDROCK
 #             or provider == LlmProviders.BASETEN
 #             or provider == LlmProviders.PETALS
+#             or provider == LlmProviders.SAGEMAKER
+#             or provider == LlmProviders.SAGEMAKER_CHAT
+#             or provider == LlmProviders.VLLM
+#             or provider == LlmProviders.OLLAMA
 #         ):
 #             continue
 #         config = ProviderConfigManager.get_provider_chat_config(
 #             model="gpt-3.5-turbo", provider=LlmProviders(provider)
 #         )