refactor partner models to include ai21

Ishaan Jaff 2024-08-27 13:35:22 -07:00
parent 415abc86c6
commit 11c175a215
6 changed files with 150 additions and 64 deletions

View file

@@ -859,9 +859,12 @@ from .llms.vertex_ai_and_google_ai_studio.vertex_ai_non_gemini import (
 from .llms.vertex_ai_and_google_ai_studio.vertex_ai_anthropic import (
     VertexAIAnthropicConfig,
 )
-from .llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models import (
+from .llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models.llama3.transformation import (
     VertexAILlama3Config,
 )
+from .llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models.ai21.transformation import (
+    VertexAIAi21Config,
+)
 from .llms.sagemaker.sagemaker import SagemakerConfig
 from .llms.ollama import OllamaConfig
 from .llms.ollama_chat import OllamaChatConfig

View file

@@ -0,0 +1,53 @@
import types
from typing import Callable, Literal, Optional, Union

import litellm


class VertexAIAi21Config:
    """
    Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/ai21

    The class `VertexAIAi21Config` provides configuration for the VertexAI's AI21 API interface

    -> Supports all OpenAI parameters
    """

    def __init__(
        self,
        max_tokens: Optional[int] = None,
    ) -> None:
        locals_ = locals()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }

    def get_supported_openai_params(self):
        return litellm.OpenAIConfig().get_supported_openai_params(model="gpt-3.5-turbo")

    def map_openai_params(
        self, non_default_params: dict, optional_params: dict, model: str
    ):
        return litellm.OpenAIConfig().map_openai_params(
            non_default_params=non_default_params,
            optional_params=optional_params,
            model=model,
        )
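
A minimal usage sketch, not part of the commit: because the config delegates to `litellm.OpenAIConfig` (as the methods above show), OpenAI-style parameters pass through unchanged. The model id is illustrative.

import litellm

config = litellm.VertexAIAi21Config()
# Supported params are whatever OpenAIConfig reports for gpt-3.5-turbo.
print(config.get_supported_openai_params())

# Map OpenAI-style arguments into the request's optional params.
optional_params = config.map_openai_params(
    non_default_params={"max_tokens": 256, "temperature": 0.2},
    optional_params={},
    model="jamba-1.5-large",  # illustrative model id
)
# Expected: {"max_tokens": 256, "temperature": 0.2}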

View file

@@ -0,0 +1,59 @@
import types
from typing import Callable, Literal, Optional, Union

import litellm


class VertexAILlama3Config:
    """
    Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/llama#streaming

    The class `VertexAILlama3Config` provides configuration for the VertexAI's Llama API interface. Below are the parameters:

    - `max_tokens` Required (integer) max tokens,

    Note: Please make sure to modify the default parameters as required for your use case.
    """

    max_tokens: Optional[int] = None

    def __init__(
        self,
        max_tokens: Optional[int] = None,
    ) -> None:
        locals_ = locals()
        for key, value in locals_.items():
            if key == "max_tokens" and value is None:
                value = self.max_tokens
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }

    def get_supported_openai_params(self):
        return litellm.OpenAIConfig().get_supported_openai_params(model="gpt-3.5-turbo")

    def map_openai_params(
        self, non_default_params: dict, optional_params: dict, model: str
    ):
        return litellm.OpenAIConfig().map_openai_params(
            non_default_params=non_default_params,
            optional_params=optional_params,
            model=model,
        )
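
A short sketch of the config pattern both transformation classes share: `__init__` writes its arguments onto the class itself via `setattr(self.__class__, ...)`, so `get_config()` reflects the most recent instantiation. This assumes the class is re-exported at the package top level, as the import hunk earlier in this commit does.

import litellm

litellm.VertexAILlama3Config(max_tokens=512)
print(litellm.VertexAILlama3Config.get_config())  # -> {'max_tokens': 512}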

View file

@@ -1,6 +1,7 @@
 # What is this?
-## Handler for calling llama 3.1 API on Vertex AI
+## API Handler for calling Vertex AI Partner Models
 import types
+from enum import Enum
 from typing import Callable, Literal, Optional, Union

 import httpx  # type: ignore
@@ -8,7 +9,13 @@ import httpx  # type: ignore
 import litellm
 from litellm.utils import ModelResponse

-from ..base import BaseLLM
+from ...base import BaseLLM
+
+
+class VertexPartnerProvider(str, Enum):
+    mistralai = "mistralai"
+    llama = "llama"
+    ai21 = "ai21"


 class VertexAIError(Exception):
@@ -24,61 +31,6 @@ class VertexAIError(Exception):
         )  # Call the base class constructor with the parameters it needs


-class VertexAILlama3Config:
-    """
-    Reference:https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/llama#streaming
-
-    The class `VertexAILlama3Config` provides configuration for the VertexAI's Llama API interface. Below are the parameters:
-
-    - `max_tokens` Required (integer) max tokens,
-
-    Note: Please make sure to modify the default parameters as required for your use case.
-    """
-
-    max_tokens: Optional[int] = None
-
-    def __init__(
-        self,
-        max_tokens: Optional[int] = None,
-    ) -> None:
-        locals_ = locals()
-        for key, value in locals_.items():
-            if key == "max_tokens" and value is None:
-                value = self.max_tokens
-            if key != "self" and value is not None:
-                setattr(self.__class__, key, value)
-
-    @classmethod
-    def get_config(cls):
-        return {
-            k: v
-            for k, v in cls.__dict__.items()
-            if not k.startswith("__")
-            and not isinstance(
-                v,
-                (
-                    types.FunctionType,
-                    types.BuiltinFunctionType,
-                    classmethod,
-                    staticmethod,
-                ),
-            )
-            and v is not None
-        }
-
-    def get_supported_openai_params(self):
-        return litellm.OpenAIConfig().get_supported_openai_params(model="gpt-3.5-turbo")
-
-    def map_openai_params(
-        self, non_default_params: dict, optional_params: dict, model: str
-    ):
-        return litellm.OpenAIConfig().map_openai_params(
-            non_default_params=non_default_params,
-            optional_params=optional_params,
-            model=model,
-        )
-
-
 class VertexAIPartnerModels(BaseLLM):
     def __init__(self) -> None:
         pass
@@ -87,17 +39,22 @@ class VertexAIPartnerModels(BaseLLM):
         self,
         vertex_location: str,
         vertex_project: str,
-        partner: Literal["llama", "mistralai"],
+        partner: VertexPartnerProvider,
         stream: Optional[bool],
         model: str,
     ) -> str:
-        if partner == "llama":
+        if partner == VertexPartnerProvider.llama:
             return f"https://{vertex_location}-aiplatform.googleapis.com/v1beta1/projects/{vertex_project}/locations/{vertex_location}/endpoints/openapi"
-        elif partner == "mistralai":
+        elif partner == VertexPartnerProvider.mistralai:
             if stream:
                 return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/mistralai/models/{model}:streamRawPredict"
             else:
                 return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/mistralai/models/{model}:rawPredict"
+        elif partner == VertexPartnerProvider.ai21:
+            if stream:
+                return f"https://{vertex_location}-aiplatform.googleapis.com/v1beta1/projects/{vertex_project}/locations/{vertex_location}/publishers/ai21/models/{model}:streamRawPredict"
+            else:
+                return f"https://{vertex_location}-aiplatform.googleapis.com/v1beta1/projects/{vertex_project}/locations/{vertex_location}/publishers/ai21/models/{model}:rawPredict"

     def completion(
         self,
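
For illustration, a sketch (not part of the diff) of the URL the new ai21 branch yields; the project and model ids below are hypothetical:

url = VertexAIPartnerModels().create_vertex_url(
    vertex_location="us-central1",
    vertex_project="my-project",         # hypothetical project id
    partner=VertexPartnerProvider.ai21,
    stream=False,
    model="jamba-1.5-mini",              # hypothetical model id
)
# url -> "https://us-central1-aiplatform.googleapis.com/v1beta1/projects/my-project/locations/us-central1/publishers/ai21/models/jamba-1.5-mini:rawPredict"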
@@ -160,9 +117,12 @@ class VertexAIPartnerModels(BaseLLM):
                optional_params["stream"] = stream

            if "llama" in model:
-                partner = VertexPartnerProvider.llama
+                partner = VertexPartnerProvider.llama
            elif "mistral" in model or "codestral" in model:
-                partner = "mistralai"
+                partner = VertexPartnerProvider.mistralai
+                optional_params["custom_endpoint"] = True
+            elif "jamba" in model:
+                partner = VertexPartnerProvider.ai21
                optional_params["custom_endpoint"] = True

            api_base = self.create_vertex_url(

View file

@@ -126,7 +126,7 @@ from .llms.vertex_ai_and_google_ai_studio import (
     vertex_ai_anthropic,
     vertex_ai_non_gemini,
 )
-from .llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models import (
+from .llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models.main import (
     VertexAIPartnerModels,
 )
 from .llms.vertex_ai_and_google_ai_studio.vertex_and_google_ai_studio_gemini import (
@@ -2080,6 +2080,7 @@ def completion(
                 model.startswith("meta/")
                 or model.startswith("mistral")
                 or model.startswith("codestral")
+                or model.startswith("jamba")
             ):
                 model_response = vertex_partner_models_chat_completion.completion(
                     model=model,
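
A hypothetical end-to-end call after this change: a jamba model id on Vertex AI now routes into the partner-models handler. The model id, project, and location below are illustrative, not taken from the commit.

import litellm

response = litellm.completion(
    model="vertex_ai/jamba-1.5-large",   # illustrative model id
    messages=[{"role": "user", "content": "Hello"}],
    max_tokens=128,
    vertex_project="my-project",         # illustrative project
    vertex_location="us-central1",
)
print(response.choices[0].message.content)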

View file

@@ -3267,6 +3267,16 @@ def get_optional_params(
             non_default_params=non_default_params,
             optional_params=optional_params,
         )
+    elif custom_llm_provider == "vertex_ai" and model in litellm.ai21_models:
+        supported_params = get_supported_openai_params(
+            model=model, custom_llm_provider=custom_llm_provider
+        )
+        _check_valid_arg(supported_params=supported_params)
+        optional_params = litellm.VertexAIAi21Config().map_openai_params(
+            non_default_params=non_default_params,
+            optional_params=optional_params,
+            model=model,
+        )
     elif custom_llm_provider == "sagemaker":
         ## check if unsupported param passed in
         supported_params = get_supported_openai_params(
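
A sketch of the new branch in action, with two assumptions flagged: that the model id is registered in litellm.ai21_models, and that get_optional_params accepts these keyword arguments, as its other provider branches suggest.

from litellm.utils import get_optional_params

optional_params = get_optional_params(
    model="jamba-1.5-large",          # assumed to be in litellm.ai21_models
    custom_llm_provider="vertex_ai",
    max_tokens=256,
    temperature=0.2,
)
# Expected (roughly): {"max_tokens": 256, "temperature": 0.2},
# mapped through VertexAIAi21Config above.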