from typing import Optional, List
from pydantic import BaseModel
from fma.toolkit import model as fma_model
from fma.toolkit.fields.image import Image
class Input(BaseModel):
    """Request schema for a text-to-image generation call."""

    prompt: str  # required text prompt
    # Explicit `= None` defaults: in pydantic v2, `Optional[...]` WITHOUT a
    # default is still a required field, which would force every caller to
    # send width/height. None lets the pipeline fall back to its default size.
    width: Optional[int] = None
    height: Optional[int] = None
    num_inference_steps: Optional[int] = 4
    guidance_scale: Optional[float] = 0.0
class Output(BaseModel):
    """Response schema: the single generated image."""

    # fma toolkit image field; populated from the pipeline's first output image.
    image: Image
class Model(fma_model.Model):
    """Stable Diffusion 3.5 Large text-to-image model with a 4-bit (NF4)
    quantized transformer and T5 text encoder."""

    # Pinned runtime dependencies installed for this model.
    requirements: List[str] = [
        "torch==2.5.0",
        "diffusers",
        "bitsandbytes==0.44.1",
        "transformers==4.45.2",
        "tokenizers==0.20.1",
        "sentencepiece==0.2.0",
        "accelerate==1.0.1",
        "numpy>=1.26.4",
        "Pillow",
        "protobuf",
    ]

    def initialize(self):
        """Download and assemble the SD 3.5 pipeline.

        Loads the 8B transformer in NF4 4-bit quantization and a
        pre-quantized NF4 T5 text encoder to fit the model in GPU memory,
        then builds the full pipeline around them.
        """
        # NOTE(security): the Hugging Face access token required to download
        # the gated model weights must be supplied via the HF_TOKEN
        # environment variable by the deployment environment. A credential
        # was previously hardcoded here; never commit secrets to source —
        # the leaked token must be revoked.
        from diffusers import BitsAndBytesConfig, SD3Transformer2DModel
        from diffusers import StableDiffusion3Pipeline
        from transformers import T5EncoderModel
        import torch

        model_id = "stabilityai/stable-diffusion-3.5-large"

        # NF4 4-bit quantization with bfloat16 compute for the transformer.
        nf4_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
        model_nf4 = SD3Transformer2DModel.from_pretrained(
            model_id,
            subfolder="transformer",
            quantization_config=nf4_config,
            torch_dtype=torch.bfloat16,
        )
        # Pre-quantized NF4 T5 encoder published under diffusers/t5-nf4.
        t5_nf4 = T5EncoderModel.from_pretrained(
            "diffusers/t5-nf4", torch_dtype=torch.bfloat16
        )
        self.pipeline = StableDiffusion3Pipeline.from_pretrained(
            model_id,
            transformer=model_nf4,
            text_encoder_3=t5_nf4,
            torch_dtype=torch.bfloat16,
        )

    def predict(self, input: Input) -> Output:
        """Generate one image from *input.prompt* and return it.

        Fixes: `input` is a pydantic model, so the previous dict-style access
        (`input["prompt"]`, `input.get(...)`) raised TypeError — pydantic
        BaseModel supports neither subscripting nor `.get`. Attribute access
        is required. Also forwards the declared width/height fields, which
        were previously ignored (None lets the pipeline pick its default).
        """
        # Fall back to the schema defaults if a caller explicitly sends null.
        steps = 4 if input.num_inference_steps is None else input.num_inference_steps
        guidance = 0.0 if input.guidance_scale is None else input.guidance_scale
        image = self.pipeline(
            prompt=input.prompt,
            width=input.width,    # None -> pipeline default resolution
            height=input.height,  # None -> pipeline default resolution
            num_inference_steps=steps,
            guidance_scale=guidance,
            max_sequence_length=512,
        ).images[0]
        return Output(image=image)