forked from phoenix/litellm-mirror
added async support for bedrock image gen
This commit is contained in:
parent
64c3c4906c
commit
092888d593
5 changed files with 300 additions and 108 deletions
|
@ -984,10 +984,10 @@ from .llms.bedrock.common_utils import (
|
||||||
AmazonAnthropicClaude3Config,
|
AmazonAnthropicClaude3Config,
|
||||||
AmazonCohereConfig,
|
AmazonCohereConfig,
|
||||||
AmazonLlamaConfig,
|
AmazonLlamaConfig,
|
||||||
AmazonStabilityConfig,
|
|
||||||
AmazonMistralConfig,
|
AmazonMistralConfig,
|
||||||
AmazonBedrockGlobalConfig,
|
AmazonBedrockGlobalConfig,
|
||||||
)
|
)
|
||||||
|
from .llms.bedrock.image.amazon_stability1_transformation import AmazonStabilityConfig
|
||||||
from .llms.bedrock.embed.amazon_titan_g1_transformation import AmazonTitanG1Config
|
from .llms.bedrock.embed.amazon_titan_g1_transformation import AmazonTitanG1Config
|
||||||
from .llms.bedrock.embed.amazon_titan_multimodal_transformation import (
|
from .llms.bedrock.embed.amazon_titan_multimodal_transformation import (
|
||||||
AmazonTitanMultimodalEmbeddingG1Config,
|
AmazonTitanMultimodalEmbeddingG1Config,
|
||||||
|
|
|
@ -484,73 +484,6 @@ class AmazonMistralConfig:
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class AmazonStabilityConfig:
|
|
||||||
"""
|
|
||||||
Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=stability.stable-diffusion-xl-v0
|
|
||||||
|
|
||||||
Supported Params for the Amazon / Stable Diffusion models:
|
|
||||||
|
|
||||||
- `cfg_scale` (integer): Default `7`. Between [ 0 .. 35 ]. How strictly the diffusion process adheres to the prompt text (higher values keep your image closer to your prompt)
|
|
||||||
|
|
||||||
- `seed` (float): Default: `0`. Between [ 0 .. 4294967295 ]. Random noise seed (omit this option or use 0 for a random seed)
|
|
||||||
|
|
||||||
- `steps` (array of strings): Default `30`. Between [ 10 .. 50 ]. Number of diffusion steps to run.
|
|
||||||
|
|
||||||
- `width` (integer): Default: `512`. multiple of 64 >= 128. Width of the image to generate, in pixels, in an increment divible by 64.
|
|
||||||
Engine-specific dimension validation:
|
|
||||||
|
|
||||||
- SDXL Beta: must be between 128x128 and 512x896 (or 896x512); only one dimension can be greater than 512.
|
|
||||||
- SDXL v0.9: must be one of 1024x1024, 1152x896, 1216x832, 1344x768, 1536x640, 640x1536, 768x1344, 832x1216, or 896x1152
|
|
||||||
- SDXL v1.0: same as SDXL v0.9
|
|
||||||
- SD v1.6: must be between 320x320 and 1536x1536
|
|
||||||
|
|
||||||
- `height` (integer): Default: `512`. multiple of 64 >= 128. Height of the image to generate, in pixels, in an increment divible by 64.
|
|
||||||
Engine-specific dimension validation:
|
|
||||||
|
|
||||||
- SDXL Beta: must be between 128x128 and 512x896 (or 896x512); only one dimension can be greater than 512.
|
|
||||||
- SDXL v0.9: must be one of 1024x1024, 1152x896, 1216x832, 1344x768, 1536x640, 640x1536, 768x1344, 832x1216, or 896x1152
|
|
||||||
- SDXL v1.0: same as SDXL v0.9
|
|
||||||
- SD v1.6: must be between 320x320 and 1536x1536
|
|
||||||
"""
|
|
||||||
|
|
||||||
cfg_scale: Optional[int] = None
|
|
||||||
seed: Optional[float] = None
|
|
||||||
steps: Optional[List[str]] = None
|
|
||||||
width: Optional[int] = None
|
|
||||||
height: Optional[int] = None
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
cfg_scale: Optional[int] = None,
|
|
||||||
seed: Optional[float] = None,
|
|
||||||
steps: Optional[List[str]] = None,
|
|
||||||
width: Optional[int] = None,
|
|
||||||
height: Optional[int] = None,
|
|
||||||
) -> None:
|
|
||||||
locals_ = locals()
|
|
||||||
for key, value in locals_.items():
|
|
||||||
if key != "self" and value is not None:
|
|
||||||
setattr(self.__class__, key, value)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_config(cls):
|
|
||||||
return {
|
|
||||||
k: v
|
|
||||||
for k, v in cls.__dict__.items()
|
|
||||||
if not k.startswith("__")
|
|
||||||
and not isinstance(
|
|
||||||
v,
|
|
||||||
(
|
|
||||||
types.FunctionType,
|
|
||||||
types.BuiltinFunctionType,
|
|
||||||
classmethod,
|
|
||||||
staticmethod,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
and v is not None
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def add_custom_header(headers):
|
def add_custom_header(headers):
|
||||||
"""Closure to capture the headers and add them."""
|
"""Closure to capture the headers and add them."""
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,69 @@
|
||||||
|
import types
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
|
||||||
|
class AmazonStabilityConfig:
|
||||||
|
"""
|
||||||
|
Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=stability.stable-diffusion-xl-v0
|
||||||
|
|
||||||
|
Supported Params for the Amazon / Stable Diffusion models:
|
||||||
|
|
||||||
|
- `cfg_scale` (integer): Default `7`. Between [ 0 .. 35 ]. How strictly the diffusion process adheres to the prompt text (higher values keep your image closer to your prompt)
|
||||||
|
|
||||||
|
- `seed` (float): Default: `0`. Between [ 0 .. 4294967295 ]. Random noise seed (omit this option or use 0 for a random seed)
|
||||||
|
|
||||||
|
- `steps` (array of strings): Default `30`. Between [ 10 .. 50 ]. Number of diffusion steps to run.
|
||||||
|
|
||||||
|
- `width` (integer): Default: `512`. multiple of 64 >= 128. Width of the image to generate, in pixels, in an increment divible by 64.
|
||||||
|
Engine-specific dimension validation:
|
||||||
|
|
||||||
|
- SDXL Beta: must be between 128x128 and 512x896 (or 896x512); only one dimension can be greater than 512.
|
||||||
|
- SDXL v0.9: must be one of 1024x1024, 1152x896, 1216x832, 1344x768, 1536x640, 640x1536, 768x1344, 832x1216, or 896x1152
|
||||||
|
- SDXL v1.0: same as SDXL v0.9
|
||||||
|
- SD v1.6: must be between 320x320 and 1536x1536
|
||||||
|
|
||||||
|
- `height` (integer): Default: `512`. multiple of 64 >= 128. Height of the image to generate, in pixels, in an increment divible by 64.
|
||||||
|
Engine-specific dimension validation:
|
||||||
|
|
||||||
|
- SDXL Beta: must be between 128x128 and 512x896 (or 896x512); only one dimension can be greater than 512.
|
||||||
|
- SDXL v0.9: must be one of 1024x1024, 1152x896, 1216x832, 1344x768, 1536x640, 640x1536, 768x1344, 832x1216, or 896x1152
|
||||||
|
- SDXL v1.0: same as SDXL v0.9
|
||||||
|
- SD v1.6: must be between 320x320 and 1536x1536
|
||||||
|
"""
|
||||||
|
|
||||||
|
cfg_scale: Optional[int] = None
|
||||||
|
seed: Optional[float] = None
|
||||||
|
steps: Optional[List[str]] = None
|
||||||
|
width: Optional[int] = None
|
||||||
|
height: Optional[int] = None
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
cfg_scale: Optional[int] = None,
|
||||||
|
seed: Optional[float] = None,
|
||||||
|
steps: Optional[List[str]] = None,
|
||||||
|
width: Optional[int] = None,
|
||||||
|
height: Optional[int] = None,
|
||||||
|
) -> None:
|
||||||
|
locals_ = locals()
|
||||||
|
for key, value in locals_.items():
|
||||||
|
if key != "self" and value is not None:
|
||||||
|
setattr(self.__class__, key, value)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_config(cls):
|
||||||
|
return {
|
||||||
|
k: v
|
||||||
|
for k, v in cls.__dict__.items()
|
||||||
|
if not k.startswith("__")
|
||||||
|
and not isinstance(
|
||||||
|
v,
|
||||||
|
(
|
||||||
|
types.FunctionType,
|
||||||
|
types.BuiltinFunctionType,
|
||||||
|
classmethod,
|
||||||
|
staticmethod,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
and v is not None
|
||||||
|
}
|
|
@ -1,38 +1,163 @@
|
||||||
import copy
|
import copy
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from typing import Any, List, Optional
|
from typing import TYPE_CHECKING, Any, List, Optional, Union
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from openai.types.image import Image
|
from openai.types.image import Image
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
import litellm
|
import litellm
|
||||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, _get_httpx_client
|
from litellm._logging import verbose_logger
|
||||||
|
from litellm.litellm_core_utils.litellm_logging import Logging as LitellmLogging
|
||||||
|
from litellm.llms.custom_httpx.http_handler import (
|
||||||
|
_get_httpx_client,
|
||||||
|
get_async_httpx_client,
|
||||||
|
)
|
||||||
from litellm.types.utils import ImageResponse
|
from litellm.types.utils import ImageResponse
|
||||||
from litellm.utils import print_verbose
|
|
||||||
|
|
||||||
from ...base_aws_llm import BaseAWSLLM
|
from ...base_aws_llm import BaseAWSLLM
|
||||||
from ..common_utils import BedrockError
|
from ..common_utils import BedrockError
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from botocore.awsrequest import AWSPreparedRequest
|
||||||
|
else:
|
||||||
|
AWSPreparedRequest = Any
|
||||||
|
|
||||||
|
|
||||||
|
class BedrockImagePreparedRequest(BaseModel):
|
||||||
|
"""
|
||||||
|
Internal/Helper class for preparing the request for bedrock image generation
|
||||||
|
"""
|
||||||
|
|
||||||
|
endpoint_url: str
|
||||||
|
prepped: AWSPreparedRequest
|
||||||
|
body: bytes
|
||||||
|
data: dict
|
||||||
|
|
||||||
|
|
||||||
class BedrockImageGeneration(BaseAWSLLM):
|
class BedrockImageGeneration(BaseAWSLLM):
|
||||||
"""
|
"""
|
||||||
Bedrock Image Generation handler
|
Bedrock Image Generation handler
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def image_generation( # noqa: PLR0915
|
def image_generation(
|
||||||
self,
|
self,
|
||||||
model: str,
|
model: str,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
model_response: ImageResponse,
|
model_response: ImageResponse,
|
||||||
optional_params: dict,
|
optional_params: dict,
|
||||||
logging_obj: Any,
|
logging_obj: LitellmLogging,
|
||||||
timeout=None,
|
timeout: Optional[Union[float, httpx.Timeout]],
|
||||||
aimg_generation: bool = False,
|
aimg_generation: bool = False,
|
||||||
api_base: Optional[str] = None,
|
api_base: Optional[str] = None,
|
||||||
extra_headers: Optional[dict] = None,
|
extra_headers: Optional[dict] = None,
|
||||||
client: Optional[Any] = None,
|
|
||||||
):
|
):
|
||||||
|
prepared_request = self._prepare_request(
|
||||||
|
model=model,
|
||||||
|
optional_params=optional_params,
|
||||||
|
api_base=api_base,
|
||||||
|
extra_headers=extra_headers,
|
||||||
|
logging_obj=logging_obj,
|
||||||
|
prompt=prompt,
|
||||||
|
)
|
||||||
|
|
||||||
|
if aimg_generation is True:
|
||||||
|
return self.async_image_generation(
|
||||||
|
prepared_request=prepared_request,
|
||||||
|
timeout=timeout,
|
||||||
|
model=model,
|
||||||
|
logging_obj=logging_obj,
|
||||||
|
prompt=prompt,
|
||||||
|
model_response=model_response,
|
||||||
|
)
|
||||||
|
|
||||||
|
client = _get_httpx_client()
|
||||||
|
try:
|
||||||
|
response = client.post(url=prepared_request.endpoint_url, headers=prepared_request.prepped.headers, data=prepared_request.body) # type: ignore
|
||||||
|
response.raise_for_status()
|
||||||
|
except httpx.HTTPStatusError as err:
|
||||||
|
error_code = err.response.status_code
|
||||||
|
raise BedrockError(status_code=error_code, message=err.response.text)
|
||||||
|
except httpx.TimeoutException:
|
||||||
|
raise BedrockError(status_code=408, message="Timeout error occurred.")
|
||||||
|
### FORMAT RESPONSE TO OPENAI FORMAT ###
|
||||||
|
model_response = self._transform_response_dict_to_openai_response(
|
||||||
|
model_response=model_response,
|
||||||
|
model=model,
|
||||||
|
logging_obj=logging_obj,
|
||||||
|
prompt=prompt,
|
||||||
|
response=response,
|
||||||
|
data=prepared_request.data,
|
||||||
|
)
|
||||||
|
return model_response
|
||||||
|
|
||||||
|
async def async_image_generation(
|
||||||
|
self,
|
||||||
|
prepared_request: BedrockImagePreparedRequest,
|
||||||
|
timeout: Optional[Union[float, httpx.Timeout]],
|
||||||
|
model: str,
|
||||||
|
logging_obj: LitellmLogging,
|
||||||
|
prompt: str,
|
||||||
|
model_response: ImageResponse,
|
||||||
|
) -> ImageResponse:
|
||||||
|
"""
|
||||||
|
Asynchronous handler for bedrock image generation
|
||||||
|
|
||||||
|
Awaits the response from the bedrock image generation endpoint
|
||||||
|
"""
|
||||||
|
async_client = get_async_httpx_client(
|
||||||
|
llm_provider=litellm.LlmProviders.BEDROCK,
|
||||||
|
params={"timeout": timeout},
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = await async_client.post(url=prepared_request.endpoint_url, headers=prepared_request.prepped.headers, data=prepared_request.body) # type: ignore
|
||||||
|
response.raise_for_status()
|
||||||
|
except httpx.HTTPStatusError as err:
|
||||||
|
error_code = err.response.status_code
|
||||||
|
raise BedrockError(status_code=error_code, message=err.response.text)
|
||||||
|
except httpx.TimeoutException:
|
||||||
|
raise BedrockError(status_code=408, message="Timeout error occurred.")
|
||||||
|
|
||||||
|
### FORMAT RESPONSE TO OPENAI FORMAT ###
|
||||||
|
model_response = self._transform_response_dict_to_openai_response(
|
||||||
|
model=model,
|
||||||
|
logging_obj=logging_obj,
|
||||||
|
prompt=prompt,
|
||||||
|
response=response,
|
||||||
|
data=prepared_request.data,
|
||||||
|
model_response=model_response,
|
||||||
|
)
|
||||||
|
return model_response
|
||||||
|
|
||||||
|
def _prepare_request(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
optional_params: dict,
|
||||||
|
api_base: Optional[str],
|
||||||
|
extra_headers: Optional[dict],
|
||||||
|
logging_obj: LitellmLogging,
|
||||||
|
prompt: str,
|
||||||
|
) -> BedrockImagePreparedRequest:
|
||||||
|
"""
|
||||||
|
Prepare the request body, headers, and endpoint URL for the Bedrock Image Generation API
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model (str): The model to use for the image generation
|
||||||
|
optional_params (dict): The optional parameters for the image generation
|
||||||
|
api_base (Optional[str]): The base URL for the Bedrock API
|
||||||
|
extra_headers (Optional[dict]): The extra headers to include in the request
|
||||||
|
logging_obj (LitellmLogging): The logging object to use for logging
|
||||||
|
prompt (str): The prompt to use for the image generation
|
||||||
|
Returns:
|
||||||
|
BedrockImagePreparedRequest: The prepared request object
|
||||||
|
|
||||||
|
The BedrockImagePreparedRequest contains:
|
||||||
|
endpoint_url (str): The endpoint URL for the Bedrock Image Generation API
|
||||||
|
prepped (httpx.Request): The prepared request object
|
||||||
|
body (bytes): The request body
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
import boto3
|
import boto3
|
||||||
from botocore.auth import SigV4Auth
|
from botocore.auth import SigV4Auth
|
||||||
|
@ -46,7 +171,7 @@ class BedrockImageGeneration(BaseAWSLLM):
|
||||||
|
|
||||||
### SET RUNTIME ENDPOINT ###
|
### SET RUNTIME ENDPOINT ###
|
||||||
modelId = model
|
modelId = model
|
||||||
endpoint_url, proxy_endpoint_url = self.get_runtime_endpoint(
|
_, proxy_endpoint_url = self.get_runtime_endpoint(
|
||||||
api_base=api_base,
|
api_base=api_base,
|
||||||
aws_bedrock_runtime_endpoint=boto3_credentials_info.aws_bedrock_runtime_endpoint,
|
aws_bedrock_runtime_endpoint=boto3_credentials_info.aws_bedrock_runtime_endpoint,
|
||||||
aws_region_name=boto3_credentials_info.aws_region_name,
|
aws_region_name=boto3_credentials_info.aws_region_name,
|
||||||
|
@ -107,27 +232,25 @@ class BedrockImageGeneration(BaseAWSLLM):
|
||||||
"headers": prepped.headers,
|
"headers": prepped.headers,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
return BedrockImagePreparedRequest(
|
||||||
|
endpoint_url=proxy_endpoint_url,
|
||||||
|
prepped=prepped,
|
||||||
|
body=body,
|
||||||
|
data=data,
|
||||||
|
)
|
||||||
|
|
||||||
if client is None or isinstance(client, AsyncHTTPHandler):
|
def _transform_response_dict_to_openai_response(
|
||||||
_params = {}
|
self,
|
||||||
if timeout is not None:
|
model_response: ImageResponse,
|
||||||
if isinstance(timeout, float) or isinstance(timeout, int):
|
model: str,
|
||||||
timeout = httpx.Timeout(timeout)
|
logging_obj: LitellmLogging,
|
||||||
_params["timeout"] = timeout
|
prompt: str,
|
||||||
client = _get_httpx_client(_params) # type: ignore
|
response: httpx.Response,
|
||||||
else:
|
data: dict,
|
||||||
client = client
|
) -> ImageResponse:
|
||||||
|
"""
|
||||||
try:
|
Transforms the Image Generation response from Bedrock to OpenAI format
|
||||||
response = client.post(url=proxy_endpoint_url, headers=prepped.headers, data=body) # type: ignore
|
"""
|
||||||
response.raise_for_status()
|
|
||||||
except httpx.HTTPStatusError as err:
|
|
||||||
error_code = err.response.status_code
|
|
||||||
raise BedrockError(status_code=error_code, message=err.response.text)
|
|
||||||
except httpx.TimeoutException:
|
|
||||||
raise BedrockError(status_code=408, message="Timeout error occurred.")
|
|
||||||
|
|
||||||
response_body = response.json()
|
|
||||||
|
|
||||||
## LOGGING
|
## LOGGING
|
||||||
if logging_obj is not None:
|
if logging_obj is not None:
|
||||||
|
@ -137,22 +260,16 @@ class BedrockImageGeneration(BaseAWSLLM):
|
||||||
original_response=response.text,
|
original_response=response.text,
|
||||||
additional_args={"complete_input_dict": data},
|
additional_args={"complete_input_dict": data},
|
||||||
)
|
)
|
||||||
print_verbose("raw model_response: %s", response.text)
|
verbose_logger.debug("raw model_response: %s", response.text)
|
||||||
|
response_dict = response.json()
|
||||||
### FORMAT RESPONSE TO OPENAI FORMAT ###
|
if response_dict is None:
|
||||||
if response_body is None:
|
raise ValueError("Error in response object format, got None")
|
||||||
raise Exception("Error in response object format")
|
|
||||||
|
|
||||||
if model_response is None:
|
|
||||||
model_response = ImageResponse()
|
|
||||||
|
|
||||||
image_list: List[Image] = []
|
image_list: List[Image] = []
|
||||||
for artifact in response_body["artifacts"]:
|
for artifact in response_dict["artifacts"]:
|
||||||
_image = Image(b64_json=artifact["base64"])
|
_image = Image(b64_json=artifact["base64"])
|
||||||
image_list.append(_image)
|
image_list.append(_image)
|
||||||
|
|
||||||
model_response.data = image_list
|
model_response.data = image_list
|
||||||
return model_response
|
|
||||||
|
|
||||||
async def async_image_generation(self):
|
return model_response
|
||||||
pass
|
|
||||||
|
|
|
@ -0,0 +1,73 @@
|
||||||
|
import copy
|
||||||
|
import os
|
||||||
|
import types
|
||||||
|
from typing import Any, Dict, List, Optional, TypedDict, Union
|
||||||
|
|
||||||
|
import litellm
|
||||||
|
|
||||||
|
|
||||||
|
class AmazonStability1Config:
|
||||||
|
"""
|
||||||
|
Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=stability.stable-diffusion-xl-v0
|
||||||
|
|
||||||
|
Supported Params for the Amazon / Stable Diffusion models:
|
||||||
|
|
||||||
|
- `cfg_scale` (integer): Default `7`. Between [ 0 .. 35 ]. How strictly the diffusion process adheres to the prompt text (higher values keep your image closer to your prompt)
|
||||||
|
|
||||||
|
- `seed` (float): Default: `0`. Between [ 0 .. 4294967295 ]. Random noise seed (omit this option or use 0 for a random seed)
|
||||||
|
|
||||||
|
- `steps` (array of strings): Default `30`. Between [ 10 .. 50 ]. Number of diffusion steps to run.
|
||||||
|
|
||||||
|
- `width` (integer): Default: `512`. multiple of 64 >= 128. Width of the image to generate, in pixels, in an increment divible by 64.
|
||||||
|
Engine-specific dimension validation:
|
||||||
|
|
||||||
|
- SDXL Beta: must be between 128x128 and 512x896 (or 896x512); only one dimension can be greater than 512.
|
||||||
|
- SDXL v0.9: must be one of 1024x1024, 1152x896, 1216x832, 1344x768, 1536x640, 640x1536, 768x1344, 832x1216, or 896x1152
|
||||||
|
- SDXL v1.0: same as SDXL v0.9
|
||||||
|
- SD v1.6: must be between 320x320 and 1536x1536
|
||||||
|
|
||||||
|
- `height` (integer): Default: `512`. multiple of 64 >= 128. Height of the image to generate, in pixels, in an increment divible by 64.
|
||||||
|
Engine-specific dimension validation:
|
||||||
|
|
||||||
|
- SDXL Beta: must be between 128x128 and 512x896 (or 896x512); only one dimension can be greater than 512.
|
||||||
|
- SDXL v0.9: must be one of 1024x1024, 1152x896, 1216x832, 1344x768, 1536x640, 640x1536, 768x1344, 832x1216, or 896x1152
|
||||||
|
- SDXL v1.0: same as SDXL v0.9
|
||||||
|
- SD v1.6: must be between 320x320 and 1536x1536
|
||||||
|
"""
|
||||||
|
|
||||||
|
cfg_scale: Optional[int] = None
|
||||||
|
seed: Optional[float] = None
|
||||||
|
steps: Optional[List[str]] = None
|
||||||
|
width: Optional[int] = None
|
||||||
|
height: Optional[int] = None
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
cfg_scale: Optional[int] = None,
|
||||||
|
seed: Optional[float] = None,
|
||||||
|
steps: Optional[List[str]] = None,
|
||||||
|
width: Optional[int] = None,
|
||||||
|
height: Optional[int] = None,
|
||||||
|
) -> None:
|
||||||
|
locals_ = locals()
|
||||||
|
for key, value in locals_.items():
|
||||||
|
if key != "self" and value is not None:
|
||||||
|
setattr(self.__class__, key, value)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_config(cls):
|
||||||
|
return {
|
||||||
|
k: v
|
||||||
|
for k, v in cls.__dict__.items()
|
||||||
|
if not k.startswith("__")
|
||||||
|
and not isinstance(
|
||||||
|
v,
|
||||||
|
(
|
||||||
|
types.FunctionType,
|
||||||
|
types.BuiltinFunctionType,
|
||||||
|
classmethod,
|
||||||
|
staticmethod,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
and v is not None
|
||||||
|
}
|
Loading…
Add table
Add a link
Reference in a new issue