Krrish Dholakia 2024-12-11 01:10:12 -08:00
parent efbec4230b
commit 02dd0c6e7e
9 changed files with 209 additions and 107 deletions

View file

@@ -1084,8 +1084,8 @@ from .llms.deprecated_providers.palm import (
PalmConfig,
) # here to prevent breaking changes
from .llms.nlp_cloud.chat.handler import NLPCloudConfig
from .llms.petals.completion.transformation import PetalsConfig
from .llms.deprecated_providers.aleph_alpha import AlephAlphaConfig
from .llms.petals import PetalsConfig
from .llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import (
VertexGeminiConfig,
GoogleAIStudioGeminiConfig,
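PetalsConfig moves to the new llms/petals/completion/transformation.py module but is still re-exported from the top-level package, mirroring the "here to prevent breaking changes" comment visible above. A minimal sketch of the compatibility this preserves, assuming the layout introduced by this commit:

    # The public name and the new internal module should resolve to the same class.
    from litellm import PetalsConfig
    from litellm.llms.petals.completion.transformation import (
        PetalsConfig as _InternalPetalsConfig,  # new location, per the diff above
    )

    assert PetalsConfig is _InternalPetalsConfig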

View file

@@ -172,7 +172,7 @@ def get_supported_openai_params( # noqa: PLR0915
elif custom_llm_provider == "nlp_cloud":
return litellm.NLPCloudConfig().get_supported_openai_params(model=model)
elif custom_llm_provider == "petals":
return ["max_tokens", "temperature", "top_p", "stream"]
return litellm.PetalsConfig().get_supported_openai_params(model=model)
elif custom_llm_provider == "deepinfra":
return litellm.DeepInfraConfig().get_supported_openai_params(model=model)
elif custom_llm_provider == "perplexity":
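The hard-coded parameter list for Petals is replaced by a lookup on the config class, matching the pattern already used for nlp_cloud and deepinfra above. A quick sketch of the expected result, assuming the PetalsConfig added later in this commit:

    import litellm

    # Delegating to the config class should return the same four parameters
    # that were previously hard-coded in this branch.
    supported = litellm.PetalsConfig().get_supported_openai_params(
        model="petals-team/StableBeluga2"
    )
    print(supported)  # expected: ["max_tokens", "temperature", "top_p", "stream"]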

View file

@@ -0,0 +1,10 @@
from typing import Union
from httpx import Headers
from litellm.llms.base_llm.transformation import BaseLLMException
class PetalsError(BaseLLMException):
def __init__(self, status_code: int, message: str, headers: Union[dict, Headers]):
super().__init__(status_code=status_code, message=message, headers=headers)
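PetalsError now subclasses the shared BaseLLMException and carries the response headers, so Petals failures flow through the common error-mapping machinery. A minimal construction sketch; the module path is inferred from the relative import (from ..common_utils import PetalsError) used by the handler below:

    from httpx import Headers

    from litellm.llms.petals.common_utils import PetalsError  # inferred path

    try:
        raise PetalsError(
            status_code=503,
            message="petals backend unavailable",
            headers=Headers({"content-type": "application/json"}),  # a plain dict also works
        )
    except PetalsError as e:
        # status_code/message are assumed to be stored by BaseLLMException
        print(e.status_code, e.message)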

View file

@@ -3,91 +3,21 @@ import os
import time
import types
from enum import Enum
from typing import Callable, Optional
import requests # type: ignore
from typing import Callable, Optional, Union
import litellm
from litellm.litellm_core_utils.prompt_templates.factory import (
custom_prompt,
prompt_factory,
)
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
HTTPHandler,
_get_httpx_client,
)
from litellm.utils import ModelResponse, Usage
from litellm.litellm_core_utils.prompt_templates.factory import custom_prompt, prompt_factory
class PetalsError(Exception):
def __init__(self, status_code, message):
self.status_code = status_code
self.message = message
super().__init__(
self.message
) # Call the base class constructor with the parameters it needs
class PetalsConfig:
"""
Reference: https://github.com/petals-infra/chat.petals.dev#post-apiv1generate
The `PetalsConfig` class encapsulates the configuration for the Petals API. The properties of this class are described below:
- `max_length` (integer): This represents the maximum length of the generated text (including the prefix) in tokens.
- `max_new_tokens` (integer): This represents the maximum number of newly generated tokens (excluding the prefix).
The generation parameters are compatible with `.generate()` from Hugging Face's Transformers library:
- `do_sample` (boolean, optional): If set to 0 (default), the API runs greedy generation. If set to 1, the API performs sampling using the parameters below:
- `temperature` (float, optional): This value sets the temperature for sampling.
- `top_k` (integer, optional): This value sets the limit for top-k sampling.
- `top_p` (float, optional): This value sets the limit for top-p (nucleus) sampling.
- `repetition_penalty` (float, optional): This helps apply the repetition penalty during text generation, as discussed in this paper.
"""
max_length: Optional[int] = None
max_new_tokens: Optional[int] = (
litellm.max_tokens
) # petals requires max tokens to be set
do_sample: Optional[bool] = None
temperature: Optional[float] = None
top_k: Optional[int] = None
top_p: Optional[float] = None
repetition_penalty: Optional[float] = None
def __init__(
self,
max_length: Optional[int] = None,
max_new_tokens: Optional[
int
] = litellm.max_tokens, # petals requires max tokens to be set
do_sample: Optional[bool] = None,
temperature: Optional[float] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
repetition_penalty: Optional[float] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
from ..common_utils import PetalsError
def completion(
@@ -102,6 +32,7 @@ def completion(
stream=False,
litellm_params=None,
logger_fn=None,
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
):
## Load Config
config = litellm.PetalsConfig.get_config()
@@ -137,7 +68,9 @@ def completion(
data = {"model": model, "inputs": prompt, **optional_params}
## COMPLETION CALL
response = requests.post(api_base, data=data)
if client is None or not isinstance(client, HTTPHandler):
client = _get_httpx_client()
response = client.post(api_base, data=data)
## LOGGING
logging_obj.post_call(
@@ -151,7 +84,11 @@ def completion(
try:
output_text = response.json()["outputs"]
except Exception as e:
raise PetalsError(status_code=response.status_code, message=str(e))
raise PetalsError(
status_code=response.status_code,
message=str(e),
headers=response.headers,
)
else:
try:
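The handler drops requests in favor of litellm's HTTPHandler and accepts an optional client, so callers and tests can inject or mock the transport. A rough sketch of exercising that injection point through the public completion() API; it mirrors the unit test added later in this commit, and the endpoint URL is illustrative:

    from unittest.mock import patch

    import litellm
    from litellm.llms.custom_httpx.http_handler import HTTPHandler

    client = HTTPHandler()

    # Only a synchronous HTTPHandler is used directly; None (or an async handler)
    # makes the handler fall back to _get_httpx_client().
    with patch.object(client, "post") as mock_post:
        try:
            litellm.completion(
                model="petals-team/StableBeluga2",
                messages=[{"role": "user", "content": "Hey"}],
                api_base="https://api.petals.dev",  # illustrative
                client=client,
            )
        except Exception:
            pass  # parsing the mocked response is expected to fail
        mock_post.assert_called_once()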

View file

@@ -0,0 +1,136 @@
import types
from typing import Any, List, Optional, Union
from httpx import Headers, Response
import litellm
from litellm.llms.base_llm.transformation import (
BaseConfig,
BaseLLMException,
LiteLLMLoggingObj,
)
from litellm.types.llms.openai import AllMessageValues
from ..common_utils import PetalsError
class PetalsConfig(BaseConfig):
"""
Reference: https://github.com/petals-infra/chat.petals.dev#post-apiv1generate
The `PetalsConfig` class encapsulates the configuration for the Petals API. The properties of this class are described below:
- `max_length` (integer): This represents the maximum length of the generated text (including the prefix) in tokens.
- `max_new_tokens` (integer): This represents the maximum number of newly generated tokens (excluding the prefix).
The generation parameters are compatible with `.generate()` from Hugging Face's Transformers library:
- `do_sample` (boolean, optional): If set to 0 (default), the API runs greedy generation. If set to 1, the API performs sampling using the parameters below:
- `temperature` (float, optional): This value sets the temperature for sampling.
- `top_k` (integer, optional): This value sets the limit for top-k sampling.
- `top_p` (float, optional): This value sets the limit for top-p (nucleus) sampling.
- `repetition_penalty` (float, optional): This helps apply the repetition penalty during text generation, as discussed in this paper.
"""
max_length: Optional[int] = None
max_new_tokens: Optional[int] = (
litellm.max_tokens
) # petals requires max tokens to be set
do_sample: Optional[bool] = None
temperature: Optional[float] = None
top_k: Optional[int] = None
top_p: Optional[float] = None
repetition_penalty: Optional[float] = None
def __init__(
self,
max_length: Optional[int] = None,
max_new_tokens: Optional[
int
] = litellm.max_tokens, # petals requires max tokens to be set
do_sample: Optional[bool] = None,
temperature: Optional[float] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
repetition_penalty: Optional[float] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return super().get_config()
def get_error_class(
self, error_message: str, status_code: int, headers: Union[dict, Headers]
) -> BaseLLMException:
return PetalsError(
status_code=status_code, message=error_message, headers=headers
)
def get_supported_openai_params(self, model: str) -> List:
return ["max_tokens", "temperature", "top_p", "stream"]
def map_openai_params(
self,
non_default_params: dict,
optional_params: dict,
model: str,
drop_params: bool,
) -> dict:
for param, value in non_default_params.items():
if param == "max_tokens":
optional_params["max_new_tokens"] = value
if param == "temperature":
optional_params["temperature"] = value
if param == "top_p":
optional_params["top_p"] = value
if param == "stream":
optional_params["stream"] = value
return optional_params
def transform_request(
self,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
headers: dict,
) -> dict:
raise NotImplementedError(
"Petals transformation currently done in handler.py. [TODO] Move to the transformation.py"
)
def transform_response(
self,
model: str,
raw_response: Response,
model_response: litellm.ModelResponse,
logging_obj: LiteLLMLoggingObj,
request_data: dict,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
encoding: Any,
api_key: Optional[str] = None,
json_mode: Optional[bool] = None,
) -> litellm.ModelResponse:
raise NotImplementedError(
"Petals transformation currently done in handler.py. [TODO] Move to the transformation.py"
)
def validate_environment(
self,
headers: dict,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
api_key: Optional[str] = None,
) -> dict:
return {}
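The BaseConfig-based class now owns the OpenAI-to-Petals parameter mapping: max_tokens becomes max_new_tokens, while temperature, top_p and stream pass through unchanged. A small sketch of that mapping, assuming the class as added above:

    import litellm

    config = litellm.PetalsConfig()

    optional_params = config.map_openai_params(
        non_default_params={"max_tokens": 64, "temperature": 0.2, "top_p": 0.9},
        optional_params={},
        model="petals-team/StableBeluga2",
        drop_params=False,
    )
    print(optional_params)
    # expected: {"max_new_tokens": 64, "temperature": 0.2, "top_p": 0.9}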

View file

@@ -94,7 +94,7 @@ from .litellm_core_utils.prompt_templates.factory import (
stringify_json_tool_call_content,
)
from .litellm_core_utils.streaming_chunk_builder_utils import ChunkProcessor
from .llms import baseten, maritalk, ollama_chat, petals
from .llms import baseten, maritalk, ollama_chat
from .llms.anthropic.chat import AnthropicChatCompletion
from .llms.azure.audio_transcriptions import AzureAudioTranscription
from .llms.azure.azure import AzureChatCompletion, _check_dynamic_azure_params
@@ -120,6 +120,7 @@ from .llms.openai.openai import OpenAIChatCompletion
from .llms.openai.transcriptions.handler import OpenAIAudioTranscription
from .llms.openai_like.chat.handler import OpenAILikeChatHandler
from .llms.openai_like.embedding.handler import OpenAILikeEmbeddingHandler
from .llms.petals.completion import handler as petals_handler
from .llms.predibase.chat.handler import PredibaseChatCompletion
from .llms.replicate.chat.handler import completion as replicate_chat_completion
from .llms.sagemaker.chat.handler import SagemakerChatHandler
@@ -2791,7 +2792,7 @@ def completion( # type: ignore # noqa: PLR0915
custom_llm_provider = "petals"
stream = optional_params.pop("stream", False)
model_response = petals.completion(
model_response = petals_handler.completion(
model=model,
messages=messages,
api_base=api_base,
@@ -2802,6 +2803,7 @@ def completion( # type: ignore # noqa: PLR0915
logger_fn=logger_fn,
encoding=encoding,
logging_obj=logging,
client=client,
)
if stream is True: ## [BETA]
# Fake streaming for petals
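From the caller's side nothing changes except that an explicit client can now be threaded through to the Petals handler, and stream=True still goes through the fake-streaming path noted above. A usage sketch, assuming a reachable Petals endpoint (the api_base below is illustrative):

    import litellm

    # Petals has no native streaming; with stream=True litellm chunks the
    # completed response instead (the [BETA] branch above).
    response = litellm.completion(
        model="petals-team/StableBeluga2",
        messages=[{"role": "user", "content": "Hello from litellm"}],
        api_base="https://chat.petals.dev/api/v1/generate",  # illustrative
        stream=True,
    )
    for chunk in response:
        print(chunk.choices[0].delta)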

View file

@@ -3335,15 +3335,16 @@ def get_optional_params( # noqa: PLR0915
model=model, custom_llm_provider=custom_llm_provider
)
_check_valid_arg(supported_params=supported_params)
# max_new_tokens=1,temperature=0.9, top_p=0.6
if max_tokens is not None:
optional_params["max_new_tokens"] = max_tokens
if temperature is not None:
optional_params["temperature"] = temperature
if top_p is not None:
optional_params["top_p"] = top_p
if stream:
optional_params["stream"] = stream
optional_params = litellm.PetalsConfig().map_openai_params(
non_default_params=non_default_params,
optional_params=optional_params,
model=model,
drop_params=(
drop_params
if drop_params is not None and isinstance(drop_params, bool)
else False
),
)
elif custom_llm_provider == "deepinfra":
supported_params = get_supported_openai_params(
model=model, custom_llm_provider=custom_llm_provider
@@ -6375,6 +6376,8 @@ class ProviderConfigManager:
return litellm.PredibaseConfig()
elif litellm.LlmProviders.TRITON == provider:
return litellm.TritonConfig()
elif litellm.LlmProviders.PETALS == provider:
return litellm.PetalsConfig()
return litellm.OpenAIGPTConfig()
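With the PETALS branch registered, the generic provider lookup resolves to the new config class instead of falling through to the OpenAI default. A quick check sketch, using the same get_provider_chat_config call that appears in the test scaffolding at the end of this commit:

    import litellm
    from litellm.utils import ProviderConfigManager

    config = ProviderConfigManager.get_provider_chat_config(
        model="petals-team/StableBeluga2",
        provider=litellm.LlmProviders.PETALS,
    )
    assert isinstance(config, litellm.PetalsConfig)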

View file

@@ -3685,18 +3685,27 @@ def test_mistral_anyscale_stream():
# error_str = traceback.format_exc()
# pytest.fail(f"Error occurred: {error_str}")
# test_completion_with_fallbacks_multiple_keys()
# def test_petals():
# try:
# response = completion(model="petals-team/StableBeluga2", messages=messages)
# # Add any assertions here to check the response
# print(response)
# response = completion(model="petals-team/StableBeluga2", messages=messages)
# # Add any assertions here to check the response
# print(response)
# except Exception as e:
# pytest.fail(f"Error occurred: {e}")
# test_completion_with_fallbacks_multiple_keys()
def test_petals():
try:
from litellm.llms.custom_httpx.http_handler import HTTPHandler
client = HTTPHandler()
with patch.object(client, "post") as mock_post:
try:
completion(
model="petals-team/StableBeluga2",
messages=messages,
client=client,
api_base="https://api.petals.dev",
)
except Exception as e:
print(f"Error occurred: {e}")
mock_post.assert_called_once()
except Exception as e:
pytest.fail(f"Error occurred: {e}")
# def test_baseten():
# try:

View file

@@ -323,8 +323,13 @@ def _check_provider_config(config: BaseConfig, provider: LlmProviders):
# or provider == LlmProviders.BEDROCK
# or provider == LlmProviders.BASETEN
# or provider == LlmProviders.PETALS
# or provider == LlmProviders.SAGEMAKER
# or provider == LlmProviders.SAGEMAKER_CHAT
# or provider == LlmProviders.VLLM
# or provider == LlmProviders.OLLAMA
# ):
# continue
# config = ProviderConfigManager.get_provider_chat_config(
# model="gpt-3.5-turbo", provider=LlmProviders(provider)
# )