refactor partner models to include ai21

commit 5f61539e90
parent 34eb1206c6
Author: Ishaan Jaff
Date: 2024-08-27 13:35:22 -07:00

6 changed files with 150 additions and 64 deletions

litellm/__init__.py

@@ -859,9 +859,12 @@ from .llms.vertex_ai_and_google_ai_studio.vertex_ai_non_gemini import (
 from .llms.vertex_ai_and_google_ai_studio.vertex_ai_anthropic import (
     VertexAIAnthropicConfig,
 )
-from .llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models import (
+from .llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models.llama3.transformation import (
     VertexAILlama3Config,
 )
+from .llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models.ai21.transformation import (
+    VertexAIAi21Config,
+)
 from .llms.sagemaker.sagemaker import SagemakerConfig
 from .llms.ollama import OllamaConfig
 from .llms.ollama_chat import OllamaChatConfig

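Note: because both configs are imported in litellm/__init__.py, they are reachable from the package top level. A quick illustrative check (not part of the diff):

import litellm

# Both partner-model configs are re-exported at the package root.
assert hasattr(litellm, "VertexAILlama3Config")
assert hasattr(litellm, "VertexAIAi21Config")
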
litellm/llms/vertex_ai_and_google_ai_studio/vertex_ai_partner_models/ai21/transformation.py

@@ -0,0 +1,53 @@
+import types
+from typing import Callable, Literal, Optional, Union
+
+import litellm
+
+
+class VertexAIAi21Config:
+    """
+    Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/ai21
+
+    The class `VertexAIAi21Config` provides configuration for Vertex AI's AI21 API interface.
+
+    -> Supports all OpenAI parameters
+    """
+
+    def __init__(
+        self,
+        max_tokens: Optional[int] = None,
+    ) -> None:
+        locals_ = locals()
+        for key, value in locals_.items():
+            if key != "self" and value is not None:
+                setattr(self.__class__, key, value)
+
+    @classmethod
+    def get_config(cls):
+        return {
+            k: v
+            for k, v in cls.__dict__.items()
+            if not k.startswith("__")
+            and not isinstance(
+                v,
+                (
+                    types.FunctionType,
+                    types.BuiltinFunctionType,
+                    classmethod,
+                    staticmethod,
+                ),
+            )
+            and v is not None
+        }
+
+    def get_supported_openai_params(self):
+        return litellm.OpenAIConfig().get_supported_openai_params(model="gpt-3.5-turbo")
+
+    def map_openai_params(
+        self, non_default_params: dict, optional_params: dict, model: str
+    ):
+        return litellm.OpenAIConfig().map_openai_params(
+            non_default_params=non_default_params,
+            optional_params=optional_params,
+            model=model,
+        )

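The new config defers entirely to OpenAIConfig, so it advertises OpenAI's parameter set unchanged. A minimal sketch of what that implies (illustrative, not part of the diff):

import litellm

params = litellm.VertexAIAi21Config().get_supported_openai_params()
# Standard OpenAI params such as temperature and max_tokens are included:
assert "temperature" in params and "max_tokens" in params
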
litellm/llms/vertex_ai_and_google_ai_studio/vertex_ai_partner_models/llama3/transformation.py

@@ -0,0 +1,59 @@
+import types
+from typing import Callable, Literal, Optional, Union
+
+import litellm
+
+
+class VertexAILlama3Config:
+    """
+    Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/llama#streaming
+
+    The class `VertexAILlama3Config` provides configuration for Vertex AI's Llama API interface. Below are the parameters:
+
+    - `max_tokens` Required (integer) max tokens,
+
+    Note: Please make sure to modify the default parameters as required for your use case.
+    """
+
+    max_tokens: Optional[int] = None
+
+    def __init__(
+        self,
+        max_tokens: Optional[int] = None,
+    ) -> None:
+        locals_ = locals()
+        for key, value in locals_.items():
+            if key == "max_tokens" and value is None:
+                value = self.max_tokens
+            if key != "self" and value is not None:
+                setattr(self.__class__, key, value)
+
+    @classmethod
+    def get_config(cls):
+        return {
+            k: v
+            for k, v in cls.__dict__.items()
+            if not k.startswith("__")
+            and not isinstance(
+                v,
+                (
+                    types.FunctionType,
+                    types.BuiltinFunctionType,
+                    classmethod,
+                    staticmethod,
+                ),
+            )
+            and v is not None
+        }
+
+    def get_supported_openai_params(self):
+        return litellm.OpenAIConfig().get_supported_openai_params(model="gpt-3.5-turbo")
+
+    def map_openai_params(
+        self, non_default_params: dict, optional_params: dict, model: str
+    ):
+        return litellm.OpenAIConfig().map_openai_params(
+            non_default_params=non_default_params,
+            optional_params=optional_params,
+            model=model,
+        )

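One subtlety worth noting: __init__ stores values via setattr(self.__class__, ...), i.e. on the class itself, so settings persist process-wide and are exactly what get_config() reads back. A small sketch, assuming no earlier code has set the value:

import litellm

litellm.VertexAILlama3Config(max_tokens=512)
# get_config() filters cls.__dict__ down to plain, non-None values:
print(litellm.VertexAILlama3Config.get_config())  # {'max_tokens': 512}
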
litellm/llms/vertex_ai_and_google_ai_studio/vertex_ai_partner_models/main.py

@@ -1,6 +1,7 @@
 # What is this?
-## Handler for calling llama 3.1 API on Vertex AI
+## API Handler for calling Vertex AI Partner Models
 
 import types
+from enum import Enum
 from typing import Callable, Literal, Optional, Union
 import httpx  # type: ignore
@@ -8,7 +9,13 @@ import httpx  # type: ignore
 
 import litellm
 from litellm.utils import ModelResponse
-from ..base import BaseLLM
+from ...base import BaseLLM
+
+
+class VertexPartnerProvider(str, Enum):
+    mistralai = "mistralai"
+    llama = "llama"
+    ai21 = "ai21"
 
 
 class VertexAIError(Exception):
@@ -24,61 +31,6 @@ class VertexAIError(Exception):
         )  # Call the base class constructor with the parameters it needs
 
 
-class VertexAILlama3Config:
-    """
-    Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/llama#streaming
-
-    The class `VertexAILlama3Config` provides configuration for Vertex AI's Llama API interface. Below are the parameters:
-
-    - `max_tokens` Required (integer) max tokens,
-
-    Note: Please make sure to modify the default parameters as required for your use case.
-    """
-
-    max_tokens: Optional[int] = None
-
-    def __init__(
-        self,
-        max_tokens: Optional[int] = None,
-    ) -> None:
-        locals_ = locals()
-        for key, value in locals_.items():
-            if key == "max_tokens" and value is None:
-                value = self.max_tokens
-            if key != "self" and value is not None:
-                setattr(self.__class__, key, value)
-
-    @classmethod
-    def get_config(cls):
-        return {
-            k: v
-            for k, v in cls.__dict__.items()
-            if not k.startswith("__")
-            and not isinstance(
-                v,
-                (
-                    types.FunctionType,
-                    types.BuiltinFunctionType,
-                    classmethod,
-                    staticmethod,
-                ),
-            )
-            and v is not None
-        }
-
-    def get_supported_openai_params(self):
-        return litellm.OpenAIConfig().get_supported_openai_params(model="gpt-3.5-turbo")
-
-    def map_openai_params(
-        self, non_default_params: dict, optional_params: dict, model: str
-    ):
-        return litellm.OpenAIConfig().map_openai_params(
-            non_default_params=non_default_params,
-            optional_params=optional_params,
-            model=model,
-        )
-
-
 class VertexAIPartnerModels(BaseLLM):
     def __init__(self) -> None:
         pass
@@ -87,17 +39,22 @@ class VertexAIPartnerModels(BaseLLM):
         self,
         vertex_location: str,
         vertex_project: str,
-        partner: Literal["llama", "mistralai"],
+        partner: VertexPartnerProvider,
         stream: Optional[bool],
         model: str,
     ) -> str:
-        if partner == "llama":
+        if partner == VertexPartnerProvider.llama:
             return f"https://{vertex_location}-aiplatform.googleapis.com/v1beta1/projects/{vertex_project}/locations/{vertex_location}/endpoints/openapi"
-        elif partner == "mistralai":
+        elif partner == VertexPartnerProvider.mistralai:
             if stream:
                 return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/mistralai/models/{model}:streamRawPredict"
             else:
                 return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/mistralai/models/{model}:rawPredict"
+        elif partner == VertexPartnerProvider.ai21:
+            if stream:
+                return f"https://{vertex_location}-aiplatform.googleapis.com/v1beta1/projects/{vertex_project}/locations/{vertex_location}/publishers/ai21/models/{model}:streamRawPredict"
+            else:
+                return f"https://{vertex_location}-aiplatform.googleapis.com/v1beta1/projects/{vertex_project}/locations/{vertex_location}/publishers/ai21/models/{model}:rawPredict"
 
     def completion(
         self,
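
For reference, the new ai21 branch produces a rawPredict URL of this shape (project, location, and model name are illustrative placeholders):

url = VertexAIPartnerModels().create_vertex_url(
    vertex_location="us-central1",
    vertex_project="my-project",
    partner=VertexPartnerProvider.ai21,
    stream=False,
    model="jamba-1.5-mini@001",
)
# -> https://us-central1-aiplatform.googleapis.com/v1beta1/projects/my-project
#    /locations/us-central1/publishers/ai21/models/jamba-1.5-mini@001:rawPredict
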
@@ -160,9 +117,12 @@
             optional_params["stream"] = stream
 
             if "llama" in model:
-                partner = "llama"
+                partner = VertexPartnerProvider.llama
             elif "mistral" in model or "codestral" in model:
-                partner = "mistralai"
+                partner = VertexPartnerProvider.mistralai
                 optional_params["custom_endpoint"] = True
+            elif "jamba" in model:
+                partner = VertexPartnerProvider.ai21
+                optional_params["custom_endpoint"] = True
 
             api_base = self.create_vertex_url(

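The partner is picked by substring matching on the model name. A standalone restatement of that logic, with illustrative model names:

def detect_partner(model: str) -> str:
    # Mirrors the branch in VertexAIPartnerModels.completion above.
    if "llama" in model:
        return "llama"
    elif "mistral" in model or "codestral" in model:
        return "mistralai"
    elif "jamba" in model:
        return "ai21"
    raise ValueError(f"unrecognized partner model: {model}")

assert detect_partner("meta/llama3-405b-instruct-maas") == "llama"
assert detect_partner("jamba-1.5-large@001") == "ai21"
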
litellm/main.py

@@ -126,7 +126,7 @@ from .llms.vertex_ai_and_google_ai_studio import (
     vertex_ai_anthropic,
     vertex_ai_non_gemini,
 )
-from .llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models import (
+from .llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models.main import (
     VertexAIPartnerModels,
 )
 from .llms.vertex_ai_and_google_ai_studio.vertex_and_google_ai_studio_gemini import (
@@ -2080,6 +2080,7 @@ def completion(
                 model.startswith("meta/")
                 or model.startswith("mistral")
                 or model.startswith("codestral")
+                or model.startswith("jamba")
             ):
                 model_response = vertex_partner_models_chat_completion.completion(
                     model=model,

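With the extra startswith("jamba") check, a call like the following now routes through the partner-models handler (all values are placeholders; valid Vertex AI credentials are assumed):

import litellm

response = litellm.completion(
    model="vertex_ai/jamba-1.5-mini@001",
    messages=[{"role": "user", "content": "Hello"}],
    max_tokens=100,
    vertex_project="my-project",
    vertex_location="us-central1",
)
print(response.choices[0].message.content)
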
litellm/utils.py

@@ -3267,6 +3267,16 @@ def get_optional_params(
             non_default_params=non_default_params,
             optional_params=optional_params,
         )
+    elif custom_llm_provider == "vertex_ai" and model in litellm.ai21_models:
+        supported_params = get_supported_openai_params(
+            model=model, custom_llm_provider=custom_llm_provider
+        )
+        _check_valid_arg(supported_params=supported_params)
+        optional_params = litellm.VertexAIAi21Config().map_openai_params(
+            non_default_params=non_default_params,
+            optional_params=optional_params,
+            model=model,
+        )
     elif custom_llm_provider == "sagemaker":
         ## check if unsupported param passed in
         supported_params = get_supported_openai_params(
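
The new branch delegates parameter mapping to the config registered in __init__.py. Roughly what it does for an AI21 model (assuming the model name appears in litellm.ai21_models):

import litellm

optional_params = litellm.VertexAIAi21Config().map_openai_params(
    non_default_params={"max_tokens": 100, "temperature": 0.1},
    optional_params={},
    model="jamba-1.5-mini@001",
)
# Both params are in OpenAI's supported set, so they pass through:
# {"max_tokens": 100, "temperature": 0.1}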