Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds ability to read images from numpy files and numpy arrays #2212

Merged
merged 11 commits into from
Jul 11, 2022
12 changes: 9 additions & 3 deletions ludwig/data/dataset_synthesizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def build_feature_parameters(features):


def build_synthetic_dataset(dataset_size: int, features: List[dict]):
"""Symthesizes a dataset for testing purposes.
"""Synthesizes a dataset for testing purposes.

:param dataset_size: (int) size of the dataset
:param features: (List[dict]) list of features to generate in YAML format.
Expand Down Expand Up @@ -280,7 +280,9 @@ def generate_audio(feature):
return audio_dest_path


def generate_image(feature):
def generate_image(feature, save_as_numpy=False):
save_as_numpy = feature.get("save_as_numpy", save_as_numpy)

try:
from torchvision.io import write_png
except ImportError:
Expand Down Expand Up @@ -318,7 +320,11 @@ def generate_image(feature):

image_dest_path = os.path.join(destination_folder, image_filename)
# save_image(torch.from_numpy(img.astype("uint8")), image_dest_path)
write_png(img, image_dest_path)
if save_as_numpy:
with open(image_dest_path, "wb") as f:
np.save(f, img.detach().cpu().numpy())
else:
write_png(img, image_dest_path)

except OSError as e:
raise OSError("Unable to create a folder for images/save image to disk." "{}".format(e))
Expand Down
19 changes: 16 additions & 3 deletions ludwig/features/image_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def get_feature_meta(column, preprocessing_parameters, backend):

@staticmethod
def _read_image_if_bytes_obj_and_resize(
img_entry: Union[bytes, torch.Tensor],
img_entry: Union[bytes, torch.Tensor, np.ndarray],
img_width: int,
img_height: int,
should_resize: bool,
Expand All @@ -157,7 +157,7 @@ def _read_image_if_bytes_obj_and_resize(
user_specified_num_channels: bool,
) -> Optional[np.ndarray]:
"""
:param img_entry Union[bytes, torch.Tensor]: if str file path to the
:param img_entry Union[bytes, torch.Tensor, np.ndarray]: if str file path to the
image else torch.Tensor of the image itself
:param img_width: expected width of the image
:param img_height: expected height of the image
Expand All @@ -176,8 +176,11 @@ def _read_image_if_bytes_obj_and_resize(
If the user specifies a number of channels, we try to convert all the
images to the specifications by dropping channels/padding 0 channels
"""

if isinstance(img_entry, bytes):
img = read_image_from_bytes_obj(img_entry, num_channels)
elif isinstance(img_entry, np.ndarray):
img = torch.from_numpy(img_entry).permute(2, 0, 1)
else:
img = img_entry

Expand Down Expand Up @@ -315,16 +318,26 @@ def _finalize_preprocessing_parameters(
else:
sample_size = 1 # Take first image

failed_entries = []
for image_entry in column.head(sample_size):
if isinstance(image_entry, str):
# Tries to read image as PNG or numpy file from the path.
image = read_image_from_path(image_entry)
else:
image = image_entry

if isinstance(image, torch.Tensor):
sample.append(image)
elif isinstance(image, np.ndarray):
sample.append(torch.from_numpy(image).permute(2, 0, 1))
else:
failed_entries.append(image_entry)
if len(sample) == 0:
raise ValueError("No readable images in sample, image dimensions cannot be inferred")
failed_entries_repr = "\n\t- ".join(failed_entries)
raise ValueError(
f"Images dimensions cannot be inferred. Failed to read {sample_size} images as samples:\n\t- "
f"{failed_entries_repr}."
)

should_resize = False
if explicit_height_width:
Expand Down
29 changes: 28 additions & 1 deletion ludwig/utils/image_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,24 @@ def read_image_from_path(path: str, num_channels: Optional[int] = None) -> Optio
def read_image_from_bytes_obj(
bytes_obj: Optional[bytes] = None, num_channels: Optional[int] = None
) -> Optional[torch.Tensor]:
"""Tries to read image as a tensor from the path.

If the path is not decodable as a PNG, attempts to read as a numpy file. If neither of these work, returns None.
"""
mode = get_image_read_mode_from_num_channels(num_channels)

image = read_image_as_png(bytes_obj, mode)
if image is None:
image = read_image_as_numpy(bytes_obj)
if image is None:
logger.warning("Unable to read image from bytes object.")
return image


def read_image_as_png(
bytes_obj: Optional[bytes] = None, mode: ImageReadMode = ImageReadMode.UNCHANGED
) -> Optional[torch.Tensor]:
"""Reads image from bytes object from a PNG file."""
try:
with BytesIO(bytes_obj) as buffer:
buffer_view = buffer.getbuffer()
Expand All @@ -107,7 +123,18 @@ def read_image_from_bytes_obj(
del buffer_view
return image
except Exception as e:
logger.warning("Failed to read image from bytes object. Original exception: " + str(e))
logger.warning(f"Failed to read image from PNG file. Original exception: {e}")
return None


def read_image_as_numpy(bytes_obj: Optional[bytes] = None) -> Optional[torch.Tensor]:
"""Reads image from bytes object from a numpy file."""
try:
with BytesIO(bytes_obj) as buffer:
image = np.load(buffer)
return torch.from_numpy(image)
except Exception as e:
logger.warning(f"Failed to read image from numpy file. Original exception: {e}")
return None


Expand Down
64 changes: 61 additions & 3 deletions tests/integration_tests/test_preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
import numpy as np
import pandas as pd
import pytest
from PIL import Image

from ludwig.api import LudwigModel
from ludwig.constants import COLUMN, PROC_COLUMN
from ludwig.constants import COLUMN, NAME, PROC_COLUMN, TRAINER
from ludwig.data.concatenate_datasets import concatenate_df
from tests.integration_tests.utils import (
audio_feature,
Expand All @@ -21,6 +22,8 @@
sequence_feature,
)

NUM_EXAMPLES = 10


@pytest.mark.parametrize("backend", ["local", "ray"])
@pytest.mark.distributed
Expand Down Expand Up @@ -80,7 +83,7 @@ def test_strip_whitespace_category(csv_filename, tmpdir):
@pytest.mark.parametrize("backend", ["local", "ray"])
@pytest.mark.distributed
def test_with_split(backend, csv_filename, tmpdir):
num_examples = 10
num_examples = NUM_EXAMPLES
train_set_size = int(num_examples * 0.8)
val_set_size = int(num_examples * 0.1)
test_set_size = int(num_examples * 0.1)
Expand Down Expand Up @@ -118,7 +121,7 @@ def test_with_split(backend, csv_filename, tmpdir):
def test_dask_known_divisions(feature_fn, csv_filename, tmpdir):
import dask.dataframe as dd

num_examples = 10
num_examples = NUM_EXAMPLES

input_features = [feature_fn(os.path.join(tmpdir, "generated_output"))]
output_features = [category_feature(vocab_size=5, reduce_input="sum")]
Expand All @@ -145,6 +148,61 @@ def test_dask_known_divisions(feature_fn, csv_filename, tmpdir):
)


@pytest.mark.parametrize("generate_images_as_numpy", [False, True])
def test_read_image_from_path(tmpdir, csv_filename, generate_images_as_numpy):
input_features = [image_feature(os.path.join(tmpdir, "generated_output"), save_as_numpy=generate_images_as_numpy)]
output_features = [category_feature(vocab_size=5, reduce_input="sum")]
data_csv = generate_data(
input_features, output_features, os.path.join(tmpdir, csv_filename), num_examples=NUM_EXAMPLES
)

config = {
"input_features": input_features,
"output_features": output_features,
"trainer": {"epochs": 2},
}

model = LudwigModel(config)
model.preprocess(
data_csv,
skip_save_processed_input=False,
)


def test_read_image_from_numpy_array(tmpdir, csv_filename):
input_features = [image_feature(os.path.join(tmpdir, "generated_output"))]
output_features = [category_feature(vocab_size=5, reduce_input="sum")]

config = {
"input_features": input_features,
"output_features": output_features,
TRAINER: {"epochs": 2},
}

data_csv = generate_data(
input_features, output_features, os.path.join(tmpdir, csv_filename), num_examples=NUM_EXAMPLES
)

df = pd.read_csv(data_csv)
processed_df_rows = []

for _, row in df.iterrows():
processed_df_rows.append(
{
input_features[0][NAME]: np.array(Image.open(row[input_features[0][NAME]])),
output_features[0][NAME]: row[output_features[0][NAME]],
}
)

df_with_images_as_numpy_arrays = pd.DataFrame(processed_df_rows)

model = LudwigModel(config)
model.preprocess(
df_with_images_as_numpy_arrays,
skip_save_processed_input=False,
)


def test_number_feature_wrong_dtype(csv_filename, tmpdir):
"""Tests that a number feature with all string values is treated as having missing values by default."""
data_csv_path = os.path.join(tmpdir, csv_filename)
Expand Down