diff --git a/litellm/__init__.py b/litellm/__init__.py
index 9de55edd27..13e9f3dd3f 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -233,6 +233,10 @@ aleph_alpha_models: List = [
 
 baseten_models: List = ["qvv0xeq", "q841o8w", "31dxrj3"] # FALCON 7B # WizardLM # Mosaic ML
 
+petals_models: List = [
+    "petals-team/StableBeluga2",
+]
+
 bedrock_models: List = [
     "amazon.titan-tg1-large",
     "ai21.j2-grande-instruct"
@@ -272,6 +276,7 @@ provider_list: List = [
     "vllm",
     "nlp_cloud",
     "bedrock",
+    "petals",
     "custom", # custom apis
 ]
 
@@ -287,6 +292,7 @@ models_by_provider: dict = {
     "vertex_ai": vertex_chat_models + vertex_text_models,
     "ai21": ai21_models,
     "bedrock": bedrock_models,
+    "petals": petals_models,
 }
 
 ####### EMBEDDING MODELS ###################
diff --git a/litellm/main.py b/litellm/main.py
index f793232368..eb209db7ef 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -34,6 +34,7 @@ from .llms import baseten
 from .llms import vllm
 from .llms import ollama
 from .llms import cohere
+from .llms import petals
 import tiktoken
 from concurrent.futures import ThreadPoolExecutor
 from typing import Callable, List, Optional, Dict
@@ -953,6 +954,31 @@ def completion(
             )
             return response
         response = model_response
+    elif (
+        custom_llm_provider == "petals"
+        or custom_llm_provider == "petals-team"
+        or model in litellm.petals_models
+    ):
+        custom_llm_provider = "petals"
+
+        model_response = petals.completion(
+            model=model,
+            messages=messages,
+            model_response=model_response,
+            print_verbose=print_verbose,
+            optional_params=optional_params,
+            litellm_params=litellm_params,
+            logger_fn=logger_fn,
+            encoding=encoding,
+            logging_obj=logging,
+        )
+        if inspect.isgenerator(model_response) or (stream == True):
+            # don't try to access the stream object directly; wrap it instead
+            response = CustomStreamWrapper(
+                model_response, model, custom_llm_provider="petals", logging_obj=logging
+            )
+            return response
+        response = model_response
     elif (
         custom_llm_provider == "custom"
     ):
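
For context (not part of the diff): with the routing above in place, a Petals-hosted model is called through litellm.completion like any other provider. Below is a minimal usage sketch; it assumes a reachable public Petals swarm serving petals-team/StableBeluga2 and no API key, and the exact petals.completion signature lives in litellm/llms/petals.py, which is not shown in this diff.

# Minimal usage sketch (assumption: the Petals routing added above is installed
# and a public Petals swarm is serving petals-team/StableBeluga2; no API key).
import litellm

response = litellm.completion(
    model="petals-team/StableBeluga2",
    messages=[{"role": "user", "content": "Hello, how are you?"}],
)
print(response["choices"][0]["message"]["content"])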