build: Squashed commit of https://github.com/BerriAI/litellm/pull/7171

Closes https://github.com/BerriAI/litellm/pull/7171

This commit is contained in: commit 02dd0c6e7e (parent efbec4230b)

9 changed files with 209 additions and 107 deletions
@@ -1084,8 +1084,8 @@ from .llms.deprecated_providers.palm import (
     PalmConfig,
 )  # here to prevent breaking changes
 from .llms.nlp_cloud.chat.handler import NLPCloudConfig
+from .llms.petals.completion.transformation import PetalsConfig
 from .llms.deprecated_providers.aleph_alpha import AlephAlphaConfig
-from .llms.petals import PetalsConfig
 from .llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import (
     VertexGeminiConfig,
     GoogleAIStudioGeminiConfig,
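Only the import path changes here; the top-level re-export keeps the public name stable. A minimal sketch of the expected effect (assuming litellm is installed):

# Sketch: the top-level alias resolves to the new module after this change.
import litellm

config = litellm.PetalsConfig()
print(type(config).__module__)  # expected: litellm.llms.petals.completion.transformation
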
@@ -172,7 +172,7 @@ def get_supported_openai_params(  # noqa: PLR0915
     elif custom_llm_provider == "nlp_cloud":
         return litellm.NLPCloudConfig().get_supported_openai_params(model=model)
     elif custom_llm_provider == "petals":
-        return ["max_tokens", "temperature", "top_p", "stream"]
+        return litellm.PetalsConfig().get_supported_openai_params(model=model)
     elif custom_llm_provider == "deepinfra":
         return litellm.DeepInfraConfig().get_supported_openai_params(model=model)
     elif custom_llm_provider == "perplexity":
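The petals branch now delegates to the config object rather than a hard-coded list; the returned parameters stay the same. A small usage sketch, assuming the public helper litellm exposes:

# Sketch: querying supported OpenAI params for the petals provider.
from litellm import get_supported_openai_params

params = get_supported_openai_params(
    model="petals-team/StableBeluga2", custom_llm_provider="petals"
)
print(params)  # expected: ["max_tokens", "temperature", "top_p", "stream"]
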
litellm/llms/petals/common_utils.py (new file, 10 lines)
@@ -0,0 +1,10 @@
+from typing import Union
+
+from httpx import Headers
+
+from litellm.llms.base_llm.transformation import BaseLLMException
+
+
+class PetalsError(BaseLLMException):
+    def __init__(self, status_code: int, message: str, headers: Union[dict, Headers]):
+        super().__init__(status_code=status_code, message=message, headers=headers)
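`BaseLLMException` gives provider errors a common shape (status code, message, response headers). A hedged sketch of constructing the new error type; the values are illustrative only:

# Sketch: raising PetalsError with response metadata attached.
from httpx import Headers

from litellm.llms.petals.common_utils import PetalsError

try:
    raise PetalsError(
        status_code=500,
        message="petals backend returned malformed output",
        headers=Headers({"content-type": "application/json"}),
    )
except PetalsError as err:
    print(err)  # status code, message and headers are forwarded to BaseLLMException
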
@@ -3,91 +3,21 @@ import os
 import time
 import types
 from enum import Enum
-from typing import Callable, Optional
+from typing import Callable, Optional, Union
 
-import requests  # type: ignore
-
 import litellm
+from litellm.litellm_core_utils.prompt_templates.factory import (
+    custom_prompt,
+    prompt_factory,
+)
+from litellm.llms.custom_httpx.http_handler import (
+    AsyncHTTPHandler,
+    HTTPHandler,
+    _get_httpx_client,
+)
 from litellm.utils import ModelResponse, Usage
 
-from litellm.litellm_core_utils.prompt_templates.factory import custom_prompt, prompt_factory
-
-
-class PetalsError(Exception):
-    def __init__(self, status_code, message):
-        self.status_code = status_code
-        self.message = message
-        super().__init__(
-            self.message
-        )  # Call the base class constructor with the parameters it needs
-
-
-class PetalsConfig:
-    """
-    Reference: https://github.com/petals-infra/chat.petals.dev#post-apiv1generate
-    The `PetalsConfig` class encapsulates the configuration for the Petals API. The properties of this class are described below:
-
-    - `max_length` (integer): This represents the maximum length of the generated text (including the prefix) in tokens.
-
-    - `max_new_tokens` (integer): This represents the maximum number of newly generated tokens (excluding the prefix).
-
-    The generation parameters are compatible with `.generate()` from Hugging Face's Transformers library:
-
-    - `do_sample` (boolean, optional): If set to 0 (default), the API runs greedy generation. If set to 1, the API performs sampling using the parameters below:
-
-    - `temperature` (float, optional): This value sets the temperature for sampling.
-
-    - `top_k` (integer, optional): This value sets the limit for top-k sampling.
-
-    - `top_p` (float, optional): This value sets the limit for top-p (nucleus) sampling.
-
-    - `repetition_penalty` (float, optional): This helps apply the repetition penalty during text generation, as discussed in this paper.
-    """
-
-    max_length: Optional[int] = None
-    max_new_tokens: Optional[int] = (
-        litellm.max_tokens
-    )  # petals requires max tokens to be set
-    do_sample: Optional[bool] = None
-    temperature: Optional[float] = None
-    top_k: Optional[int] = None
-    top_p: Optional[float] = None
-    repetition_penalty: Optional[float] = None
-
-    def __init__(
-        self,
-        max_length: Optional[int] = None,
-        max_new_tokens: Optional[
-            int
-        ] = litellm.max_tokens,  # petals requires max tokens to be set
-        do_sample: Optional[bool] = None,
-        temperature: Optional[float] = None,
-        top_k: Optional[int] = None,
-        top_p: Optional[float] = None,
-        repetition_penalty: Optional[float] = None,
-    ) -> None:
-        locals_ = locals()
-        for key, value in locals_.items():
-            if key != "self" and value is not None:
-                setattr(self.__class__, key, value)
-
-    @classmethod
-    def get_config(cls):
-        return {
-            k: v
-            for k, v in cls.__dict__.items()
-            if not k.startswith("__")
-            and not isinstance(
-                v,
-                (
-                    types.FunctionType,
-                    types.BuiltinFunctionType,
-                    classmethod,
-                    staticmethod,
-                ),
-            )
-            and v is not None
-        }
+from ..common_utils import PetalsError
 
 
 def completion(
@@ -102,6 +32,7 @@ def completion(
     stream=False,
     litellm_params=None,
     logger_fn=None,
+    client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
 ):
     ## Load Config
     config = litellm.PetalsConfig.get_config()
@@ -137,7 +68,9 @@ def completion(
         data = {"model": model, "inputs": prompt, **optional_params}
 
         ## COMPLETION CALL
-        response = requests.post(api_base, data=data)
+        if client is None or not isinstance(client, HTTPHandler):
+            client = _get_httpx_client()
+        response = client.post(api_base, data=data)
 
         ## LOGGING
         logging_obj.post_call(
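Replacing the module-level `requests.post` with an injectable `HTTPHandler` is what makes the new unit test further down in this diff possible: callers can pass a client whose `post` method is patched. A rough sketch of the pattern; `post_to_petals` is a hypothetical helper, not the handler's actual code:

# Sketch: dependency-injected HTTP client, mirroring the guard added above.
from typing import Optional

from litellm.llms.custom_httpx.http_handler import HTTPHandler, _get_httpx_client


def post_to_petals(api_base: str, data: dict, client: Optional[HTTPHandler] = None):
    # Fall back to a fresh client when none (or the wrong type) is supplied.
    if client is None or not isinstance(client, HTTPHandler):
        client = _get_httpx_client()
    return client.post(api_base, data=data)
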
@@ -151,7 +84,11 @@ def completion(
         try:
             output_text = response.json()["outputs"]
         except Exception as e:
-            PetalsError(status_code=response.status_code, message=str(e))
+            PetalsError(
+                status_code=response.status_code,
+                message=str(e),
+                headers=response.headers,
+            )
 
     else:
         try:
litellm/llms/petals/completion/transformation.py (new file, 136 lines)
@@ -0,0 +1,136 @@
+import types
+from typing import Any, List, Optional, Union
+
+from httpx import Headers, Response
+
+import litellm
+from litellm.llms.base_llm.transformation import (
+    BaseConfig,
+    BaseLLMException,
+    LiteLLMLoggingObj,
+)
+from litellm.types.llms.openai import AllMessageValues
+
+from ..common_utils import PetalsError
+
+
+class PetalsConfig(BaseConfig):
+    """
+    Reference: https://github.com/petals-infra/chat.petals.dev#post-apiv1generate
+    The `PetalsConfig` class encapsulates the configuration for the Petals API. The properties of this class are described below:
+
+    - `max_length` (integer): This represents the maximum length of the generated text (including the prefix) in tokens.
+
+    - `max_new_tokens` (integer): This represents the maximum number of newly generated tokens (excluding the prefix).
+
+    The generation parameters are compatible with `.generate()` from Hugging Face's Transformers library:
+
+    - `do_sample` (boolean, optional): If set to 0 (default), the API runs greedy generation. If set to 1, the API performs sampling using the parameters below:
+
+    - `temperature` (float, optional): This value sets the temperature for sampling.
+
+    - `top_k` (integer, optional): This value sets the limit for top-k sampling.
+
+    - `top_p` (float, optional): This value sets the limit for top-p (nucleus) sampling.
+
+    - `repetition_penalty` (float, optional): This helps apply the repetition penalty during text generation, as discussed in this paper.
+    """
+
+    max_length: Optional[int] = None
+    max_new_tokens: Optional[int] = (
+        litellm.max_tokens
+    )  # petals requires max tokens to be set
+    do_sample: Optional[bool] = None
+    temperature: Optional[float] = None
+    top_k: Optional[int] = None
+    top_p: Optional[float] = None
+    repetition_penalty: Optional[float] = None
+
+    def __init__(
+        self,
+        max_length: Optional[int] = None,
+        max_new_tokens: Optional[
+            int
+        ] = litellm.max_tokens,  # petals requires max tokens to be set
+        do_sample: Optional[bool] = None,
+        temperature: Optional[float] = None,
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None,
+        repetition_penalty: Optional[float] = None,
+    ) -> None:
+        locals_ = locals()
+        for key, value in locals_.items():
+            if key != "self" and value is not None:
+                setattr(self.__class__, key, value)
+
+    @classmethod
+    def get_config(cls):
+        return super().get_config()
+
+    def get_error_class(
+        self, error_message: str, status_code: int, headers: Union[dict, Headers]
+    ) -> BaseLLMException:
+        return PetalsError(
+            status_code=status_code, message=error_message, headers=headers
+        )
+
+    def get_supported_openai_params(self, model: str) -> List:
+        return ["max_tokens", "temperature", "top_p", "stream"]
+
+    def map_openai_params(
+        self,
+        non_default_params: dict,
+        optional_params: dict,
+        model: str,
+        drop_params: bool,
+    ) -> dict:
+        for param, value in non_default_params.items():
+            if param == "max_tokens":
+                optional_params["max_new_tokens"] = value
+            if param == "temperature":
+                optional_params["temperature"] = value
+            if param == "top_p":
+                optional_params["top_p"] = value
+            if param == "stream":
+                optional_params["stream"] = value
+        return optional_params
+
+    def transform_request(
+        self,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        headers: dict,
+    ) -> dict:
+        raise NotImplementedError(
+            "Petals transformation currently done in handler.py. [TODO] Move to the transformation.py"
+        )
+
+    def transform_response(
+        self,
+        model: str,
+        raw_response: Response,
+        model_response: litellm.ModelResponse,
+        logging_obj: LiteLLMLoggingObj,
+        request_data: dict,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        encoding: Any,
+        api_key: Optional[str] = None,
+        json_mode: Optional[bool] = None,
+    ) -> litellm.ModelResponse:
+        raise NotImplementedError(
+            "Petals transformation currently done in handler.py. [TODO] Move to the transformation.py"
+        )
+
+    def validate_environment(
+        self,
+        headers: dict,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        api_key: Optional[str] = None,
+    ) -> dict:
+        return {}
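The piece callers interact with most is `map_openai_params`, which renames OpenAI-style arguments into Petals' request fields (`max_tokens` becomes `max_new_tokens`). A short sketch with illustrative values:

# Sketch: mapping OpenAI params to the Petals request shape.
import litellm

mapped = litellm.PetalsConfig().map_openai_params(
    non_default_params={"max_tokens": 256, "temperature": 0.7, "top_p": 0.9},
    optional_params={},
    model="petals-team/StableBeluga2",
    drop_params=False,
)
print(mapped)  # expected: {'max_new_tokens': 256, 'temperature': 0.7, 'top_p': 0.9}
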
@@ -94,7 +94,7 @@ from .litellm_core_utils.prompt_templates.factory import (
     stringify_json_tool_call_content,
 )
 from .litellm_core_utils.streaming_chunk_builder_utils import ChunkProcessor
-from .llms import baseten, maritalk, ollama_chat, petals
+from .llms import baseten, maritalk, ollama_chat
 from .llms.anthropic.chat import AnthropicChatCompletion
 from .llms.azure.audio_transcriptions import AzureAudioTranscription
 from .llms.azure.azure import AzureChatCompletion, _check_dynamic_azure_params
@@ -120,6 +120,7 @@ from .llms.openai.openai import OpenAIChatCompletion
 from .llms.openai.transcriptions.handler import OpenAIAudioTranscription
 from .llms.openai_like.chat.handler import OpenAILikeChatHandler
 from .llms.openai_like.embedding.handler import OpenAILikeEmbeddingHandler
+from .llms.petals.completion import handler as petals_handler
 from .llms.predibase.chat.handler import PredibaseChatCompletion
 from .llms.replicate.chat.handler import completion as replicate_chat_completion
 from .llms.sagemaker.chat.handler import SagemakerChatHandler
@@ -2791,7 +2792,7 @@ def completion(  # type: ignore  # noqa: PLR0915
 
            custom_llm_provider = "petals"
            stream = optional_params.pop("stream", False)
-            model_response = petals.completion(
+            model_response = petals_handler.completion(
                model=model,
                messages=messages,
                api_base=api_base,
@@ -2802,6 +2803,7 @@ def completion(  # type: ignore  # noqa: PLR0915
                logger_fn=logger_fn,
                encoding=encoding,
                logging_obj=logging,
+                client=client,
            )
            if stream is True:  ## [BETA]
                # Fake streaming for petals
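End to end, the petals branch of `completion()` now routes through `petals_handler.completion` and forwards the optional `client`. A hedged usage sketch; the `petals/` model prefix and the api_base are illustrative, and a reachable Petals endpoint is assumed:

# Sketch: calling the petals provider via litellm.completion.
import litellm

response = litellm.completion(
    model="petals/petals-team/StableBeluga2",
    messages=[{"role": "user", "content": "Hello, how are you?"}],
    api_base="https://api.petals.dev",  # same endpoint the new unit test uses
    max_tokens=64,
)
print(response.choices[0].message.content)
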
@@ -3335,15 +3335,16 @@ def get_optional_params(  # noqa: PLR0915
            model=model, custom_llm_provider=custom_llm_provider
        )
        _check_valid_arg(supported_params=supported_params)
-        # max_new_tokens=1,temperature=0.9, top_p=0.6
-        if max_tokens is not None:
-            optional_params["max_new_tokens"] = max_tokens
-        if temperature is not None:
-            optional_params["temperature"] = temperature
-        if top_p is not None:
-            optional_params["top_p"] = top_p
-        if stream:
-            optional_params["stream"] = stream
+        optional_params = litellm.PetalsConfig().map_openai_params(
+            non_default_params=non_default_params,
+            optional_params=optional_params,
+            model=model,
+            drop_params=(
+                drop_params
+                if drop_params is not None and isinstance(drop_params, bool)
+                else False
+            ),
+        )
    elif custom_llm_provider == "deepinfra":
        supported_params = get_supported_openai_params(
            model=model, custom_llm_provider=custom_llm_provider
@@ -6375,6 +6376,8 @@ class ProviderConfigManager:
            return litellm.PredibaseConfig()
        elif litellm.LlmProviders.TRITON == provider:
            return litellm.TritonConfig()
+        elif litellm.LlmProviders.PETALS == provider:
+            return litellm.PetalsConfig()
        return litellm.OpenAIGPTConfig()
 
 
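With the `PETALS` branch added, `ProviderConfigManager` resolves a petals-specific chat config instead of falling through to the OpenAI default. A minimal sketch; the import path for the manager is assumed from where this class is defined:

# Sketch: resolving the provider chat config for petals.
import litellm
from litellm.utils import ProviderConfigManager

config = ProviderConfigManager.get_provider_chat_config(
    model="petals-team/StableBeluga2", provider=litellm.LlmProviders.PETALS
)
print(type(config).__name__)  # expected: PetalsConfig
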
@@ -3685,18 +3685,27 @@ def test_mistral_anyscale_stream():
 #         error_str = traceback.format_exc()
 #         pytest.fail(f"Error occurred: {error_str}")
 
-# test_completion_with_fallbacks_multiple_keys()
-# def test_petals():
-#     try:
-#         response = completion(model="petals-team/StableBeluga2", messages=messages)
-#         # Add any assertions here to check the response
-#         print(response)
-
-#         response = completion(model="petals-team/StableBeluga2", messages=messages)
-#         # Add any assertions here to check the response
-#         print(response)
-#     except Exception as e:
-#         pytest.fail(f"Error occurred: {e}")
+# test_completion_with_fallbacks_multiple_keys()
+def test_petals():
+    try:
+        from litellm.llms.custom_httpx.http_handler import HTTPHandler
+
+        client = HTTPHandler()
+        with patch.object(client, "post") as mock_post:
+            try:
+                completion(
+                    model="petals-team/StableBeluga2",
+                    messages=messages,
+                    client=client,
+                    api_base="https://api.petals.dev",
+                )
+            except Exception as e:
+                print(f"Error occurred: {e}")
+            mock_post.assert_called_once()
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
 
 
 # def test_baseten():
 #     try:
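The new test works because `post` is patched on the injected client, so no request ever leaves the machine. A stripped-down version of that pattern, outside the test harness:

# Sketch: the mocked-client pattern used by test_petals, in isolation.
from unittest.mock import patch

from litellm.llms.custom_httpx.http_handler import HTTPHandler

client = HTTPHandler()
with patch.object(client, "post") as mock_post:
    client.post("https://api.petals.dev", data={"inputs": "Hello"})
    mock_post.assert_called_once()
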
@@ -323,8 +323,13 @@ def _check_provider_config(config: BaseConfig, provider: LlmProviders):
 #         or provider == LlmProviders.BEDROCK
 #         or provider == LlmProviders.BASETEN
 #         or provider == LlmProviders.PETALS
+#         or provider == LlmProviders.SAGEMAKER
+#         or provider == LlmProviders.SAGEMAKER_CHAT
+#         or provider == LlmProviders.VLLM
+#         or provider == LlmProviders.OLLAMA
 #     ):
 #         continue
 
 #     config = ProviderConfigManager.get_provider_chat_config(
 #         model="gpt-3.5-turbo", provider=LlmProviders(provider)
 #     )