Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 03:04:13 +00:00)

build: Squashed commit of https://github.com/BerriAI/litellm/pull/7171

Closes https://github.com/BerriAI/litellm/pull/7171

This commit is contained in:
parent efbec4230b
commit 02dd0c6e7e

9 changed files with 209 additions and 107 deletions
@@ -1084,8 +1084,8 @@ from .llms.deprecated_providers.palm import (
     PalmConfig,
 )  # here to prevent breaking changes
 from .llms.nlp_cloud.chat.handler import NLPCloudConfig
+from .llms.petals.completion.transformation import PetalsConfig
 from .llms.deprecated_providers.aleph_alpha import AlephAlphaConfig
-from .llms.petals import PetalsConfig
 from .llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import (
     VertexGeminiConfig,
     GoogleAIStudioGeminiConfig,
@@ -172,7 +172,7 @@ def get_supported_openai_params(  # noqa: PLR0915
     elif custom_llm_provider == "nlp_cloud":
         return litellm.NLPCloudConfig().get_supported_openai_params(model=model)
     elif custom_llm_provider == "petals":
-        return ["max_tokens", "temperature", "top_p", "stream"]
+        return litellm.PetalsConfig().get_supported_openai_params(model=model)
     elif custom_llm_provider == "deepinfra":
         return litellm.DeepInfraConfig().get_supported_openai_params(model=model)
     elif custom_llm_provider == "perplexity":
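With this change, the petals branch resolves its supported OpenAI parameters through the config object instead of a hardcoded list. A minimal sketch of what a caller could expect, assuming the public litellm.get_supported_openai_params helper (the model name is only illustrative):

import litellm

# Both lookups should yield ["max_tokens", "temperature", "top_p", "stream"],
# now sourced from PetalsConfig rather than an inline literal.
via_helper = litellm.get_supported_openai_params(
    model="petals-team/StableBeluga2", custom_llm_provider="petals"
)
via_config = litellm.PetalsConfig().get_supported_openai_params(
    model="petals-team/StableBeluga2"
)
assert via_helper == via_config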
litellm/llms/petals/common_utils.py (new file)
@@ -0,0 +1,10 @@
from typing import Union

from httpx import Headers

from litellm.llms.base_llm.transformation import BaseLLMException


class PetalsError(BaseLLMException):
    def __init__(self, status_code: int, message: str, headers: Union[dict, Headers]):
        super().__init__(status_code=status_code, message=message, headers=headers)
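PetalsError now subclasses BaseLLMException and carries the response headers, so Petals failures surface with the same shape as other providers' errors. A rough, hypothetical illustration of raising it from a failed httpx response (this helper is not part of the commit):

import httpx

from litellm.llms.petals.common_utils import PetalsError


def raise_petals_error(response: httpx.Response) -> None:
    # Bundle status code, body text, and headers into the shared exception type.
    raise PetalsError(
        status_code=response.status_code,
        message=response.text,
        headers=response.headers,
    )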
@@ -3,91 +3,21 @@ import os
 import time
-import types
-from enum import Enum
-from typing import Callable, Optional
-
-import requests  # type: ignore
+from typing import Callable, Optional, Union

 import litellm
+from litellm.litellm_core_utils.prompt_templates.factory import (
+    custom_prompt,
+    prompt_factory,
+)
+from litellm.llms.custom_httpx.http_handler import (
+    AsyncHTTPHandler,
+    HTTPHandler,
+    _get_httpx_client,
+)
+from litellm.utils import ModelResponse, Usage
-from litellm.litellm_core_utils.prompt_templates.factory import custom_prompt, prompt_factory
-
-
-class PetalsError(Exception):
-    def __init__(self, status_code, message):
-        self.status_code = status_code
-        self.message = message
-        super().__init__(
-            self.message
-        )  # Call the base class constructor with the parameters it needs
-
-
-class PetalsConfig:
-    """
-    Reference: https://github.com/petals-infra/chat.petals.dev#post-apiv1generate
-    The `PetalsConfig` class encapsulates the configuration for the Petals API. The properties of this class are described below:
-
-    - `max_length` (integer): This represents the maximum length of the generated text (including the prefix) in tokens.
-
-    - `max_new_tokens` (integer): This represents the maximum number of newly generated tokens (excluding the prefix).
-
-    The generation parameters are compatible with `.generate()` from Hugging Face's Transformers library:
-
-    - `do_sample` (boolean, optional): If set to 0 (default), the API runs greedy generation. If set to 1, the API performs sampling using the parameters below:
-
-    - `temperature` (float, optional): This value sets the temperature for sampling.
-
-    - `top_k` (integer, optional): This value sets the limit for top-k sampling.
-
-    - `top_p` (float, optional): This value sets the limit for top-p (nucleus) sampling.
-
-    - `repetition_penalty` (float, optional): This helps apply the repetition penalty during text generation, as discussed in this paper.
-    """
-
-    max_length: Optional[int] = None
-    max_new_tokens: Optional[int] = (
-        litellm.max_tokens
-    )  # petals requires max tokens to be set
-    do_sample: Optional[bool] = None
-    temperature: Optional[float] = None
-    top_k: Optional[int] = None
-    top_p: Optional[float] = None
-    repetition_penalty: Optional[float] = None
-
-    def __init__(
-        self,
-        max_length: Optional[int] = None,
-        max_new_tokens: Optional[
-            int
-        ] = litellm.max_tokens,  # petals requires max tokens to be set
-        do_sample: Optional[bool] = None,
-        temperature: Optional[float] = None,
-        top_k: Optional[int] = None,
-        top_p: Optional[float] = None,
-        repetition_penalty: Optional[float] = None,
-    ) -> None:
-        locals_ = locals()
-        for key, value in locals_.items():
-            if key != "self" and value is not None:
-                setattr(self.__class__, key, value)
-
-    @classmethod
-    def get_config(cls):
-        return {
-            k: v
-            for k, v in cls.__dict__.items()
-            if not k.startswith("__")
-            and not isinstance(
-                v,
-                (
-                    types.FunctionType,
-                    types.BuiltinFunctionType,
-                    classmethod,
-                    staticmethod,
-                ),
-            )
-            and v is not None
-        }
+from ..common_utils import PetalsError


 def completion(
@@ -102,6 +32,7 @@ def completion(
     stream=False,
     litellm_params=None,
     logger_fn=None,
+    client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
 ):
     ## Load Config
     config = litellm.PetalsConfig.get_config()
@@ -137,7 +68,9 @@ def completion(
         data = {"model": model, "inputs": prompt, **optional_params}

         ## COMPLETION CALL
-        response = requests.post(api_base, data=data)
+        if client is None or not isinstance(client, HTTPHandler):
+            client = _get_httpx_client()
+        response = client.post(api_base, data=data)

         ## LOGGING
         logging_obj.post_call(
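Swapping the module-level requests.post for litellm's HTTPHandler means callers and tests can inject their own client, while _get_httpx_client() supplies a default when none is passed. A minimal sketch of the injection path, assuming a reachable Petals endpoint (it mirrors the test added later in this commit):

from litellm import completion
from litellm.llms.custom_httpx.http_handler import HTTPHandler

client = HTTPHandler()  # reusable connection pool; also easy to mock in tests
response = completion(
    model="petals-team/StableBeluga2",
    messages=[{"role": "user", "content": "Hello"}],
    client=client,
    api_base="https://api.petals.dev",
)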
@@ -151,7 +84,11 @@ def completion(
         try:
             output_text = response.json()["outputs"]
         except Exception as e:
-            PetalsError(status_code=response.status_code, message=str(e))
+            PetalsError(
+                status_code=response.status_code,
+                message=str(e),
+                headers=response.headers,
+            )

     else:
         try:
litellm/llms/petals/completion/transformation.py (new file)
@@ -0,0 +1,136 @@
import types
from typing import Any, List, Optional, Union

from httpx import Headers, Response

import litellm
from litellm.llms.base_llm.transformation import (
    BaseConfig,
    BaseLLMException,
    LiteLLMLoggingObj,
)
from litellm.types.llms.openai import AllMessageValues

from ..common_utils import PetalsError


class PetalsConfig(BaseConfig):
    """
    Reference: https://github.com/petals-infra/chat.petals.dev#post-apiv1generate
    The `PetalsConfig` class encapsulates the configuration for the Petals API. The properties of this class are described below:

    - `max_length` (integer): This represents the maximum length of the generated text (including the prefix) in tokens.

    - `max_new_tokens` (integer): This represents the maximum number of newly generated tokens (excluding the prefix).

    The generation parameters are compatible with `.generate()` from Hugging Face's Transformers library:

    - `do_sample` (boolean, optional): If set to 0 (default), the API runs greedy generation. If set to 1, the API performs sampling using the parameters below:

    - `temperature` (float, optional): This value sets the temperature for sampling.

    - `top_k` (integer, optional): This value sets the limit for top-k sampling.

    - `top_p` (float, optional): This value sets the limit for top-p (nucleus) sampling.

    - `repetition_penalty` (float, optional): This helps apply the repetition penalty during text generation, as discussed in this paper.
    """

    max_length: Optional[int] = None
    max_new_tokens: Optional[int] = (
        litellm.max_tokens
    )  # petals requires max tokens to be set
    do_sample: Optional[bool] = None
    temperature: Optional[float] = None
    top_k: Optional[int] = None
    top_p: Optional[float] = None
    repetition_penalty: Optional[float] = None

    def __init__(
        self,
        max_length: Optional[int] = None,
        max_new_tokens: Optional[
            int
        ] = litellm.max_tokens,  # petals requires max tokens to be set
        do_sample: Optional[bool] = None,
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        repetition_penalty: Optional[float] = None,
    ) -> None:
        locals_ = locals()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        return super().get_config()

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, Headers]
    ) -> BaseLLMException:
        return PetalsError(
            status_code=status_code, message=error_message, headers=headers
        )

    def get_supported_openai_params(self, model: str) -> List:
        return ["max_tokens", "temperature", "top_p", "stream"]

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        for param, value in non_default_params.items():
            if param == "max_tokens":
                optional_params["max_new_tokens"] = value
            if param == "temperature":
                optional_params["temperature"] = value
            if param == "top_p":
                optional_params["top_p"] = value
            if param == "stream":
                optional_params["stream"] = value
        return optional_params

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        raise NotImplementedError(
            "Petals transformation currently done in handler.py. [TODO] Move to the transformation.py"
        )

    def transform_response(
        self,
        model: str,
        raw_response: Response,
        model_response: litellm.ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> litellm.ModelResponse:
        raise NotImplementedError(
            "Petals transformation currently done in handler.py. [TODO] Move to the transformation.py"
        )

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        api_key: Optional[str] = None,
    ) -> dict:
        return {}
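Because the parameter mapping now lives on the config class, it can be exercised directly; a short sketch of the expected translation, where max_tokens becomes Petals' max_new_tokens and the other supported parameters pass through unchanged:

import litellm

config = litellm.PetalsConfig()
optional_params = config.map_openai_params(
    non_default_params={"max_tokens": 256, "temperature": 0.7, "top_p": 0.9},
    optional_params={},
    model="petals-team/StableBeluga2",
    drop_params=False,
)
# Expected: {"max_new_tokens": 256, "temperature": 0.7, "top_p": 0.9}
print(optional_params)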
@@ -94,7 +94,7 @@ from .litellm_core_utils.prompt_templates.factory import (
     stringify_json_tool_call_content,
 )
 from .litellm_core_utils.streaming_chunk_builder_utils import ChunkProcessor
-from .llms import baseten, maritalk, ollama_chat, petals
+from .llms import baseten, maritalk, ollama_chat
 from .llms.anthropic.chat import AnthropicChatCompletion
 from .llms.azure.audio_transcriptions import AzureAudioTranscription
 from .llms.azure.azure import AzureChatCompletion, _check_dynamic_azure_params
@@ -120,6 +120,7 @@ from .llms.openai.openai import OpenAIChatCompletion
 from .llms.openai.transcriptions.handler import OpenAIAudioTranscription
 from .llms.openai_like.chat.handler import OpenAILikeChatHandler
 from .llms.openai_like.embedding.handler import OpenAILikeEmbeddingHandler
+from .llms.petals.completion import handler as petals_handler
 from .llms.predibase.chat.handler import PredibaseChatCompletion
 from .llms.replicate.chat.handler import completion as replicate_chat_completion
 from .llms.sagemaker.chat.handler import SagemakerChatHandler
@@ -2791,7 +2792,7 @@ def completion(  # type: ignore  # noqa: PLR0915

         custom_llm_provider = "petals"
         stream = optional_params.pop("stream", False)
-        model_response = petals.completion(
+        model_response = petals_handler.completion(
             model=model,
             messages=messages,
             api_base=api_base,
@@ -2802,6 +2803,7 @@ def completion(  # type: ignore  # noqa: PLR0915
             logger_fn=logger_fn,
             encoding=encoding,
             logging_obj=logging,
+            client=client,
         )
         if stream is True:  ## [BETA]
             # Fake streaming for petals
@@ -3335,15 +3335,16 @@ def get_optional_params(  # noqa: PLR0915
             model=model, custom_llm_provider=custom_llm_provider
         )
         _check_valid_arg(supported_params=supported_params)
-        # max_new_tokens=1,temperature=0.9, top_p=0.6
-        if max_tokens is not None:
-            optional_params["max_new_tokens"] = max_tokens
-        if temperature is not None:
-            optional_params["temperature"] = temperature
-        if top_p is not None:
-            optional_params["top_p"] = top_p
-        if stream:
-            optional_params["stream"] = stream
+        optional_params = litellm.PetalsConfig().map_openai_params(
+            non_default_params=non_default_params,
+            optional_params=optional_params,
+            model=model,
+            drop_params=(
+                drop_params
+                if drop_params is not None and isinstance(drop_params, bool)
+                else False
+            ),
+        )
     elif custom_llm_provider == "deepinfra":
         supported_params = get_supported_openai_params(
             model=model, custom_llm_provider=custom_llm_provider
@@ -6375,6 +6376,8 @@ class ProviderConfigManager:
             return litellm.PredibaseConfig()
         elif litellm.LlmProviders.TRITON == provider:
             return litellm.TritonConfig()
+        elif litellm.LlmProviders.PETALS == provider:
+            return litellm.PetalsConfig()
         return litellm.OpenAIGPTConfig()
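With PETALS registered in ProviderConfigManager, the provider's chat config can be resolved generically, which is what the provider-config test touched below iterates over. A rough sketch, assuming ProviderConfigManager is importable from litellm.utils as in that test file:

import litellm
from litellm import LlmProviders
from litellm.utils import ProviderConfigManager  # assumed import path

config = ProviderConfigManager.get_provider_chat_config(
    model="petals-team/StableBeluga2", provider=LlmProviders.PETALS
)
assert isinstance(config, litellm.PetalsConfig)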
@@ -3685,18 +3685,27 @@ def test_mistral_anyscale_stream():
 #     error_str = traceback.format_exc()
 #     pytest.fail(f"Error occurred: {error_str}")

-# test_completion_with_fallbacks_multiple_keys()
-# def test_petals():
-#     try:
-#         response = completion(model="petals-team/StableBeluga2", messages=messages)
-#         # Add any assertions here to check the response
-#         print(response)
-
-#         response = completion(model="petals-team/StableBeluga2", messages=messages)
-#         # Add any assertions here to check the response
-#         print(response)
-#     except Exception as e:
-#         pytest.fail(f"Error occurred: {e}")
+# test_completion_with_fallbacks_multiple_keys()
+def test_petals():
+    try:
+        from litellm.llms.custom_httpx.http_handler import HTTPHandler
+
+        client = HTTPHandler()
+        with patch.object(client, "post") as mock_post:
+            try:
+                completion(
+                    model="petals-team/StableBeluga2",
+                    messages=messages,
+                    client=client,
+                    api_base="https://api.petals.dev",
+                )
+            except Exception as e:
+                print(f"Error occurred: {e}")
+        mock_post.assert_called_once()
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")


 # def test_baseten():
 #     try:
@@ -323,8 +323,13 @@ def _check_provider_config(config: BaseConfig, provider: LlmProviders):
    #     or provider == LlmProviders.BEDROCK
    #     or provider == LlmProviders.BASETEN
    #     or provider == LlmProviders.PETALS
    #     or provider == LlmProviders.SAGEMAKER
    #     or provider == LlmProviders.SAGEMAKER_CHAT
    #     or provider == LlmProviders.VLLM
    #     or provider == LlmProviders.OLLAMA
    # ):
    #     continue

    # config = ProviderConfigManager.get_provider_chat_config(
    #     model="gpt-3.5-turbo", provider=LlmProviders(provider)
    # )