import math
import os
import hashlib
from urllib.request import urlretrieve
import zipfile
import gzip
import shutil
import numpy as np
from PIL import Image
from tqdm import tqdm


def _read32(bytestream):
    """
    Read a 32-bit big-endian integer from a bytestream
    :param bytestream: A bytestream
    :return: 32-bit integer
    """
    dt = np.dtype(np.uint32).newbyteorder('>')
    return np.frombuffer(bytestream.read(4), dtype=dt)[0]


def _unzip(save_path, _, database_name, data_path):
    """
    Unzip wrapper with the same interface as _ungzip
    :param save_path: The path of the zip file
    :param _: HACK - Used to have the same interface as _ungzip
    :param database_name: Name of database
    :param data_path: Path to extract to
    """
    print('Extracting {}...'.format(database_name))
    with zipfile.ZipFile(save_path) as zf:
        zf.extractall(data_path)


def _ungzip(save_path, extract_path, database_name, _):
    """
    Unzip a gzip file and extract it to extract_path
    :param save_path: The path of the gzip file
    :param extract_path: The location to extract the data to
    :param database_name: Name of database
    :param _: HACK - Used to have the same interface as _unzip
    """
    # Get data from save_path
    with open(save_path, 'rb') as f:
        with gzip.GzipFile(fileobj=f) as bytestream:
            magic = _read32(bytestream)
            if magic != 2051:
                raise ValueError('Invalid magic number {} in file: {}'.format(magic, f.name))
            num_images = _read32(bytestream)
            rows = _read32(bytestream)
            cols = _read32(bytestream)
            buf = bytestream.read(rows * cols * num_images)
            data = np.frombuffer(buf, dtype=np.uint8)
            data = data.reshape(num_images, rows, cols)

    # Save data to extract_path
    for image_i, image in enumerate(
            tqdm(data, unit='File', unit_scale=True, miniters=1,
                 desc='Extracting {}'.format(database_name))):
        Image.fromarray(image, 'L').save(os.path.join(extract_path, 'image_{}.jpg'.format(image_i)))


def get_image(image_path, width, height, mode):
    """
    Read image from image_path
    :param image_path: Path of image
    :param width: Width of image
    :param height: Height of image
    :param mode: Mode of image
    :return: Image data
    """
    image = Image.open(image_path)

    if image.size != (width, height):  # HACK - Check if image is from the CELEBA dataset
        # Remove most pixels that aren't part of a face
        face_width = face_height = 108
        j = (image.size[0] - face_width) // 2
        i = (image.size[1] - face_height) // 2
        image = image.crop([j, i, j + face_width, i + face_height])
        image = image.resize([width, height], Image.BILINEAR)

    return np.array(image.convert(mode))


def get_batch(image_files, width, height, mode):
    """
    Read a batch of images from image_files
    :param image_files: List of image paths
    :param width: Width of each image
    :param height: Height of each image
    :param mode: Mode of the images
    :return: Batch of image data as a 4-D array (batch, height, width, channels)
    """
    data_batch = np.array(
        [get_image(sample_file, width, height, mode) for sample_file in image_files]).astype(np.float32)

    # Make sure the images are in 4 dimensions
    if len(data_batch.shape) < 4:
        data_batch = data_batch.reshape(data_batch.shape + (1,))

    return data_batch


def images_square_grid(images, mode):
    """
    Arrange images in a square grid
    :param images: Images to be used for the grid
    :param mode: The mode to use for images
    :return: Image of images in a square grid
    """
    # Get maximum size for square grid of images
    save_size = math.floor(np.sqrt(images.shape[0]))

    # Scale to 0-255
    images = (((images - images.min()) * 255) / (images.max() - images.min())).astype(np.uint8)

    # Put images in a square arrangement
    images_in_square = np.reshape(
        images[:save_size * save_size],
        (save_size, save_size, images.shape[1], images.shape[2], images.shape[3]))
    if mode == 'L':
        images_in_square = np.squeeze(images_in_square, 4)

    # Combine images to grid image
    new_im = Image.new(mode, (images.shape[1] * save_size, images.shape[2] * save_size))
    for col_i, col_images in enumerate(images_in_square):
        for image_i, image in enumerate(col_images):
            im = Image.fromarray(image, mode)
            new_im.paste(im, (col_i * images.shape[1], image_i * images.shape[2]))

    return new_im


def download_extract(database_name, data_path):
    """
    Download and extract database
    :param database_name: Database name
    :param data_path: Directory to download and extract into
    """
    DATASET_CELEBA_NAME = 'celeba'
    DATASET_MNIST_NAME = 'mnist'

    if database_name == DATASET_CELEBA_NAME:
        url = 'https://s3-us-west-1.amazonaws.com/udacity-dlnfd/datasets/celeba.zip'
        hash_code = '00d2c5bc6d35e252742224ab0c1e8fcb'
        extract_path = os.path.join(data_path, 'img_align_celeba')
        save_path = os.path.join(data_path, 'celeba.zip')
        extract_fn = _unzip
    elif database_name == DATASET_MNIST_NAME:
        url = 'http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz'
        hash_code = 'f68b3c2dcbeaaa9fbdd348bbdeb94873'
        extract_path = os.path.join(data_path, 'mnist')
        save_path = os.path.join(data_path, 'train-images-idx3-ubyte.gz')
        extract_fn = _ungzip
    else:
        raise ValueError('Unknown database: {}'.format(database_name))

    if os.path.exists(extract_path):
        print('Found {} Data'.format(database_name))
        return

    if not os.path.exists(data_path):
        os.makedirs(data_path)

    if not os.path.exists(save_path):
        with DLProgress(unit='B', unit_scale=True, miniters=1,
                        desc='Downloading {}'.format(database_name)) as pbar:
            urlretrieve(url, save_path, pbar.hook)

    with open(save_path, 'rb') as f:
        assert hashlib.md5(f.read()).hexdigest() == hash_code, \
            '{} file is corrupted. Remove the file and try again.'.format(save_path)

    os.makedirs(extract_path)
    try:
        extract_fn(save_path, extract_path, database_name, data_path)
    except Exception as err:
        shutil.rmtree(extract_path)  # Remove extraction folder if there is an error
        raise err

    # Remove compressed data
    os.remove(save_path)


class Dataset(object):
    """
    Dataset
    """
    def __init__(self, dataset_name, data_files):
        """
        Initialize the class
        :param dataset_name: Database name
        :param data_files: List of files in the database
        """
        DATASET_CELEBA_NAME = 'celeba'
        DATASET_MNIST_NAME = 'mnist'
        IMAGE_WIDTH = 28
        IMAGE_HEIGHT = 28

        if dataset_name == DATASET_CELEBA_NAME:
            self.image_mode = 'RGB'
            image_channels = 3
        elif dataset_name == DATASET_MNIST_NAME:
            self.image_mode = 'L'
            image_channels = 1
        else:
            raise ValueError('Unknown dataset: {}'.format(dataset_name))

        self.data_files = data_files
        self.shape = len(data_files), IMAGE_WIDTH, IMAGE_HEIGHT, image_channels

    def get_batches(self, batch_size):
        """
        Generate batches
        :param batch_size: Batch Size
        :return: Batches of data
        """
        IMAGE_MAX_VALUE = 255

        current_index = 0
        while current_index + batch_size <= self.shape[0]:
            data_batch = get_batch(
                self.data_files[current_index:current_index + batch_size],
                *self.shape[1:3],
                self.image_mode)

            current_index += batch_size

            # Scale pixel values from [0, 255] to [-0.5, 0.5]
            yield data_batch / IMAGE_MAX_VALUE - 0.5


class DLProgress(tqdm):
    """
    Handle Progress Bar while Downloading
    """
    last_block = 0

    def hook(self, block_num=1, block_size=1, total_size=None):
        """
        A hook function that will be called once on establishment of the network connection and
        once after each block read thereafter.
        :param block_num: A count of blocks transferred so far
        :param block_size: Block size in bytes
        :param total_size: The total size of the file. This may be -1 on older FTP servers which
            do not return a file size in response to a retrieval request.
        """
        self.total = total_size
        self.update((block_num - self.last_block) * block_size)
        self.last_block = block_num
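

# ---------------------------------------------------------------------------
# Usage sketch (an addition for illustration, not part of the original
# module). It shows the intended call sequence: download and extract a
# dataset, then stream normalized batches from it. The './data' directory and
# the glob pattern are assumptions; the 'mnist/*.jpg' layout follows from
# where _ungzip writes the extracted images.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from glob import glob

    data_dir = './data'
    download_extract('mnist', data_dir)

    # Build a Dataset over the extracted images and pull one batch.
    mnist_dataset = Dataset('mnist', glob(os.path.join(data_dir, 'mnist/*.jpg')))
    for batch in mnist_dataset.get_batches(batch_size=64):
        # Each batch has shape (64, 28, 28, 1) with values in [-0.5, 0.5].
        print(batch.shape, batch.min(), batch.max())
        break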