Skip to content

Commit

Permalink
Flux Autoencoder (#2098)
Browse files Browse the repository at this point in the history
  • Loading branch information
calvinpelletier authored Jan 8, 2025
1 parent 38bf427 commit cce8ef6
Show file tree
Hide file tree
Showing 6 changed files with 627 additions and 0 deletions.
5 changes: 5 additions & 0 deletions tests/torchtune/models/flux/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
81 changes: 81 additions & 0 deletions tests/torchtune/models/flux/test_flux_autoencoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import pytest
import torch

from torchtune.models.flux import flux_1_autoencoder
from torchtune.training.seed import set_seed

# Deliberately tiny configuration so the autoencoder tests run quickly.
BSZ = 32  # leading (batch) dimension of the random test tensors
CH_IN = 3  # passed to the builder as `ch_in`
RESOLUTION = 16  # side length of the square test images
CH_MULTS = [1, 2]  # passed to the builder as `ch_mults`
CH_Z = 4  # passed to the builder as `ch_z`
# Side length of the latent. Assumes the encoder halves the resolution once
# between consecutive `CH_MULTS` levels (as in the Flux reference autoencoder),
# i.e. a total factor of 2 ** (levels - 1). Dividing by `len(CH_MULTS)` only
# happens to give the same answer for exactly two levels.
RES_Z = RESOLUTION // 2 ** (len(CH_MULTS) - 1)


@pytest.fixture(autouse=True)
def random():
    """Call ``set_seed(0)`` before every test in this module.

    The fixture is ``autouse``, so tests pick it up without requesting it.
    The hard-coded expected values in the tests below presumably depend on
    this fixed seed.
    """
    set_seed(0)


class TestFluxAutoencoder:
    """Shape, regression-value and gradient checks for a small Flux autoencoder.

    The hard-coded expected tensors are regression references: they are only
    meaningful together with the seed set by the module's autouse fixture and
    the uniform parameter initialisation done in the ``model`` fixture.
    """

    @staticmethod
    def _assert_channel_means(output, expected):
        """Assert the per-channel means of ``output`` match ``expected``.

        Args:
            output: 4-D tensor; the mean is taken over dims 0, 2 and 3,
                leaving one value per entry along dim 1.
            expected: list of floats, one reference mean per channel.
        """
        actual = torch.mean(output, dim=(0, 2, 3))
        torch.testing.assert_close(
            actual, torch.tensor(expected), atol=1e-4, rtol=1e-4
        )

    @pytest.fixture
    def model(self):
        """Build a small autoencoder with deterministic uniform weights."""
        model = flux_1_autoencoder(
            resolution=RESOLUTION,
            ch_in=CH_IN,
            ch_out=3,
            ch_base=32,
            ch_mults=CH_MULTS,
            ch_z=CH_Z,
            n_layers_per_resample_block=2,
            scale_factor=1.0,
            shift_factor=0.0,
        )

        # Overwrite every parameter so the outputs depend only on the seed,
        # not on the default initialisation of each layer type.
        for param in model.parameters():
            param.data.uniform_(0, 0.1)

        return model

    @pytest.fixture
    def img(self):
        """Random image batch of shape (BSZ, CH_IN, RESOLUTION, RESOLUTION)."""
        return torch.randn(BSZ, CH_IN, RESOLUTION, RESOLUTION)

    @pytest.fixture
    def z(self):
        """Random latent batch of shape (BSZ, CH_Z, RES_Z, RES_Z)."""
        return torch.randn(BSZ, CH_Z, RES_Z, RES_Z)

    def test_forward(self, model, img):
        """Calling the model keeps the image shape and matches reference means."""
        actual = model(img)
        assert actual.shape == (BSZ, CH_IN, RESOLUTION, RESOLUTION)
        self._assert_channel_means(actual, [0.4286, 0.4276, 0.4054])

    def test_backward(self, model, img):
        """A scalar loss on the output can be backpropagated into the model."""
        y = model(img)
        loss = y.mean()
        loss.backward()

        # Without these checks the test would only catch exceptions; make sure
        # the backward pass actually produced a usable training signal.
        assert torch.isfinite(loss)
        assert any(param.grad is not None for param in model.parameters())

    def test_encode(self, model, img):
        """``encode`` yields a (BSZ, CH_Z, RES_Z, RES_Z) latent with reference means."""
        actual = model.encode(img)
        assert actual.shape == (BSZ, CH_Z, RES_Z, RES_Z)
        self._assert_channel_means(actual, [0.6150, 0.7959, 0.7178, 0.7011])

    def test_decode(self, model, z):
        """``decode`` maps a latent back to image shape with reference means."""
        actual = model.decode(z)
        assert actual.shape == (BSZ, CH_IN, RESOLUTION, RESOLUTION)
        self._assert_channel_means(actual, [0.4246, 0.4241, 0.4014])
10 changes: 10 additions & 0 deletions torchtune/models/flux/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from ._model_builders import flux_1_autoencoder

# Public API of this package: the names exported by a star-import.
__all__ = [
"flux_1_autoencoder",
]
Loading

0 comments on commit cce8ef6

Please sign in to comment.