Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-27 03:34:10 +00:00
refactor partner models to include ai21

commit 5f61539e90 (parent 34eb1206c6)
6 changed files with 150 additions and 64 deletions
litellm/__init__.py
@@ -859,9 +859,12 @@ from .llms.vertex_ai_and_google_ai_studio.vertex_ai_non_gemini import (
 from .llms.vertex_ai_and_google_ai_studio.vertex_ai_anthropic import (
     VertexAIAnthropicConfig,
 )
-from .llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models import (
+from .llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models.llama3.transformation import (
     VertexAILlama3Config,
 )
+from .llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models.ai21.transformation import (
+    VertexAIAi21Config,
+)
 from .llms.sagemaker.sagemaker import SagemakerConfig
 from .llms.ollama import OllamaConfig
 from .llms.ollama_chat import OllamaChatConfig
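Note: after this hunk, both partner-model configs are re-exported at the litellm package root. A quick smoke check, as a sketch assuming this commit is installed:

import litellm

# Both config classes should now resolve at the package root.
assert hasattr(litellm, "VertexAILlama3Config")
assert hasattr(litellm, "VertexAIAi21Config")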
litellm/llms/vertex_ai_and_google_ai_studio/vertex_ai_partner_models/ai21/transformation.py
@@ -0,0 +1,53 @@
+import types
+from typing import Callable, Literal, Optional, Union
+
+import litellm
+
+
+class VertexAIAi21Config:
+    """
+    Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/ai21
+
+    The class `VertexAIAi21Config` provides configuration for Vertex AI's AI21 API interface.
+
+    -> Supports all OpenAI parameters
+    """
+
+    def __init__(
+        self,
+        max_tokens: Optional[int] = None,
+    ) -> None:
+        locals_ = locals()
+        for key, value in locals_.items():
+            if key != "self" and value is not None:
+                setattr(self.__class__, key, value)
+
+    @classmethod
+    def get_config(cls):
+        return {
+            k: v
+            for k, v in cls.__dict__.items()
+            if not k.startswith("__")
+            and not isinstance(
+                v,
+                (
+                    types.FunctionType,
+                    types.BuiltinFunctionType,
+                    classmethod,
+                    staticmethod,
+                ),
+            )
+            and v is not None
+        }
+
+    def get_supported_openai_params(self):
+        return litellm.OpenAIConfig().get_supported_openai_params(model="gpt-3.5-turbo")
+
+    def map_openai_params(
+        self, non_default_params: dict, optional_params: dict, model: str
+    ):
+        return litellm.OpenAIConfig().map_openai_params(
+            non_default_params=non_default_params,
+            optional_params=optional_params,
+            model=model,
+        )
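Note: the new AI21 config delegates all parameter handling to OpenAIConfig. A minimal sketch of how it is exercised, assuming this commit is installed; the jamba model name is illustrative:

import litellm

config = litellm.VertexAIAi21Config(max_tokens=256)

# get_config() surfaces the plain, non-callable class attributes that were set.
print(config.get_config())  # e.g. {"max_tokens": 256}

# OpenAI-style params pass through via OpenAIConfig's mapping.
optional_params = config.map_openai_params(
    non_default_params={"max_tokens": 256, "temperature": 0.2},
    optional_params={},
    model="jamba-1.5-mini@001",  # illustrative model name
)
print(optional_params)  # {"max_tokens": 256, "temperature": 0.2}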
litellm/llms/vertex_ai_and_google_ai_studio/vertex_ai_partner_models/llama3/transformation.py
@@ -0,0 +1,59 @@
+import types
+from typing import Callable, Literal, Optional, Union
+
+import litellm
+
+
+class VertexAILlama3Config:
+    """
+    Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/llama#streaming
+
+    The class `VertexAILlama3Config` provides configuration for Vertex AI's Llama API interface. Below are the parameters:
+
+    - `max_tokens` Required (integer): max tokens
+
+    Note: Please make sure to modify the default parameters as required for your use case.
+    """
+
+    max_tokens: Optional[int] = None
+
+    def __init__(
+        self,
+        max_tokens: Optional[int] = None,
+    ) -> None:
+        locals_ = locals()
+        for key, value in locals_.items():
+            if key == "max_tokens" and value is None:
+                value = self.max_tokens
+            if key != "self" and value is not None:
+                setattr(self.__class__, key, value)
+
+    @classmethod
+    def get_config(cls):
+        return {
+            k: v
+            for k, v in cls.__dict__.items()
+            if not k.startswith("__")
+            and not isinstance(
+                v,
+                (
+                    types.FunctionType,
+                    types.BuiltinFunctionType,
+                    classmethod,
+                    staticmethod,
+                ),
+            )
+            and v is not None
+        }
+
+    def get_supported_openai_params(self):
+        return litellm.OpenAIConfig().get_supported_openai_params(model="gpt-3.5-turbo")
+
+    def map_openai_params(
+        self, non_default_params: dict, optional_params: dict, model: str
+    ):
+        return litellm.OpenAIConfig().map_openai_params(
+            non_default_params=non_default_params,
+            optional_params=optional_params,
+            model=model,
+        )
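Note for reviewers: __init__ writes through setattr(self.__class__, ...), so a value configured once becomes a class-level default for later instances. A small sketch of that behavior, assuming this commit is installed:

import litellm

# Setting max_tokens on one instance mutates the class attribute.
litellm.VertexAILlama3Config(max_tokens=512)
print(litellm.VertexAILlama3Config.get_config())  # {"max_tokens": 512}

# A later instance constructed without max_tokens falls back to the class value.
later = litellm.VertexAILlama3Config()
print(later.max_tokens)  # 512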
litellm/llms/vertex_ai_and_google_ai_studio/vertex_ai_partner_models/main.py
@@ -1,6 +1,7 @@
 # What is this?
-## Handler for calling llama 3.1 API on Vertex AI
+## API Handler for calling Vertex AI Partner Models
 import types
+from enum import Enum
 from typing import Callable, Literal, Optional, Union
 
 import httpx  # type: ignore
@@ -8,7 +9,13 @@ import httpx  # type: ignore
 import litellm
 from litellm.utils import ModelResponse
 
-from ..base import BaseLLM
+from ...base import BaseLLM
+
+
+class VertexPartnerProvider(str, Enum):
+    mistralai = "mistralai"
+    llama = "llama"
+    ai21 = "ai21"
 
 
 class VertexAIError(Exception):
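Note: because VertexPartnerProvider mixes in str, members still compare equal to the raw strings used before this refactor, so any remaining string-based callers keep working. A standalone sketch:

from enum import Enum

class VertexPartnerProvider(str, Enum):
    mistralai = "mistralai"
    llama = "llama"
    ai21 = "ai21"

assert VertexPartnerProvider.llama == "llama"      # str mixin: equality with raw strings
assert VertexPartnerProvider.ai21.value == "ai21"  # .value is stable for URL templates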
@@ -24,61 +31,6 @@
         ) # Call the base class constructor with the parameters it needs
 
 
-class VertexAILlama3Config:
-    """
-    Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/llama#streaming
-
-    The class `VertexAILlama3Config` provides configuration for Vertex AI's Llama API interface. Below are the parameters:
-
-    - `max_tokens` Required (integer): max tokens
-
-    Note: Please make sure to modify the default parameters as required for your use case.
-    """
-
-    max_tokens: Optional[int] = None
-
-    def __init__(
-        self,
-        max_tokens: Optional[int] = None,
-    ) -> None:
-        locals_ = locals()
-        for key, value in locals_.items():
-            if key == "max_tokens" and value is None:
-                value = self.max_tokens
-            if key != "self" and value is not None:
-                setattr(self.__class__, key, value)
-
-    @classmethod
-    def get_config(cls):
-        return {
-            k: v
-            for k, v in cls.__dict__.items()
-            if not k.startswith("__")
-            and not isinstance(
-                v,
-                (
-                    types.FunctionType,
-                    types.BuiltinFunctionType,
-                    classmethod,
-                    staticmethod,
-                ),
-            )
-            and v is not None
-        }
-
-    def get_supported_openai_params(self):
-        return litellm.OpenAIConfig().get_supported_openai_params(model="gpt-3.5-turbo")
-
-    def map_openai_params(
-        self, non_default_params: dict, optional_params: dict, model: str
-    ):
-        return litellm.OpenAIConfig().map_openai_params(
-            non_default_params=non_default_params,
-            optional_params=optional_params,
-            model=model,
-        )
-
-
 class VertexAIPartnerModels(BaseLLM):
     def __init__(self) -> None:
         pass
@@ -87,17 +39,22 @@ class VertexAIPartnerModels(BaseLLM):
         self,
         vertex_location: str,
         vertex_project: str,
-        partner: Literal["llama", "mistralai"],
+        partner: VertexPartnerProvider,
         stream: Optional[bool],
        model: str,
     ) -> str:
-        if partner == "llama":
+        if partner == VertexPartnerProvider.llama:
             return f"https://{vertex_location}-aiplatform.googleapis.com/v1beta1/projects/{vertex_project}/locations/{vertex_location}/endpoints/openapi"
-        elif partner == "mistralai":
+        elif partner == VertexPartnerProvider.mistralai:
             if stream:
                 return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/mistralai/models/{model}:streamRawPredict"
             else:
                 return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/mistralai/models/{model}:rawPredict"
+        elif partner == VertexPartnerProvider.ai21:
+            if stream:
+                return f"https://{vertex_location}-aiplatform.googleapis.com/v1beta1/projects/{vertex_project}/locations/{vertex_location}/publishers/ai21/models/{model}:streamRawPredict"
+            else:
+                return f"https://{vertex_location}-aiplatform.googleapis.com/v1beta1/projects/{vertex_project}/locations/{vertex_location}/publishers/ai21/models/{model}:rawPredict"
 
     def completion(
         self,
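Note: for AI21, both streaming and non-streaming calls resolve to the v1beta1 publisher endpoints. A sketch of the resulting URL under this commit; project, location, and model values are illustrative:

from litellm.llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models.main import (
    VertexAIPartnerModels,
    VertexPartnerProvider,
)

url = VertexAIPartnerModels().create_vertex_url(
    vertex_location="us-central1",   # illustrative
    vertex_project="my-project",     # illustrative
    partner=VertexPartnerProvider.ai21,
    stream=False,
    model="jamba-1.5-mini@001",      # illustrative
)
print(url)
# https://us-central1-aiplatform.googleapis.com/v1beta1/projects/my-project/locations/us-central1/publishers/ai21/models/jamba-1.5-mini@001:rawPredict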
@@ -160,9 +117,12 @@
             optional_params["stream"] = stream
 
             if "llama" in model:
-                partner = "llama"
+                partner = VertexPartnerProvider.llama
             elif "mistral" in model or "codestral" in model:
-                partner = "mistralai"
+                partner = VertexPartnerProvider.mistralai
                 optional_params["custom_endpoint"] = True
+            elif "jamba" in model:
+                partner = VertexPartnerProvider.ai21
+                optional_params["custom_endpoint"] = True
 
             api_base = self.create_vertex_url(
litellm/main.py
@@ -126,7 +126,7 @@ from .llms.vertex_ai_and_google_ai_studio import (
     vertex_ai_anthropic,
     vertex_ai_non_gemini,
 )
-from .llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models import (
+from .llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models.main import (
     VertexAIPartnerModels,
 )
 from .llms.vertex_ai_and_google_ai_studio.vertex_and_google_ai_studio_gemini import (
@@ -2080,6 +2080,7 @@ def completion(
                 model.startswith("meta/")
                 or model.startswith("mistral")
                 or model.startswith("codestral")
+                or model.startswith("jamba")
             ):
                 model_response = vertex_partner_models_chat_completion.completion(
                     model=model,
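Note: with the extra startswith("jamba") branch, AI21's Jamba models route through the same partner-models handler as the Meta and Mistral models. An end-to-end sketch of the intended call path; model name, project, and location are illustrative, and the model must be available to your Vertex project:

import litellm

response = litellm.completion(
    model="vertex_ai/jamba-1.5-mini@001",  # illustrative model name
    messages=[{"role": "user", "content": "Say hello."}],
    vertex_project="my-project",           # illustrative
    vertex_location="us-central1",         # illustrative
    max_tokens=64,
)
print(response.choices[0].message.content)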
litellm/utils.py
@@ -3267,6 +3267,16 @@ def get_optional_params(
             non_default_params=non_default_params,
             optional_params=optional_params,
         )
+    elif custom_llm_provider == "vertex_ai" and model in litellm.ai21_models:
+        supported_params = get_supported_openai_params(
+            model=model, custom_llm_provider=custom_llm_provider
+        )
+        _check_valid_arg(supported_params=supported_params)
+        optional_params = litellm.VertexAIAi21Config().map_openai_params(
+            non_default_params=non_default_params,
+            optional_params=optional_params,
+            model=model,
+        )
     elif custom_llm_provider == "sagemaker":
         ## check if unsupported param passed in
         supported_params = get_supported_openai_params(
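Note: the new get_optional_params branch follows the same validate-then-map pattern as the neighboring providers: check every passed param against the supported set, then delegate the mapping to the config class. Conceptually, as a standalone sketch rather than litellm's actual helpers:

def map_params_for_vertex_ai21(non_default_params: dict, supported: set) -> dict:
    # Reject anything outside the supported OpenAI-param set...
    unsupported = set(non_default_params) - supported
    if unsupported:
        raise ValueError(f"unsupported params for vertex_ai/ai21: {unsupported}")
    # ...then pass the rest through unchanged, as OpenAIConfig's mapping does.
    return dict(non_default_params)

print(map_params_for_vertex_ai21({"max_tokens": 128}, {"max_tokens", "temperature"}))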