#include "stdafx.h"
#include "Ico.h"
#include "Png.h"
#include "Core/Io/MemStream.h"
#include "Utils/Bitwise.h"
#include "Exception.h"

namespace graphics {

	// Read exactly sizeof(T) bytes from 'src' and copy them verbatim into 'out'.
	//
	// Returns true on success. If the stream could not provide sizeof(T) bytes, false is returned
	// and 'out' is left untouched.
	template <class T>
	static bool fill(IStream *src, T &out) {
		GcPreArray<byte, sizeof(T)> storage;
		Buffer got = src->fill(emptyBuffer(storage));

		const bool complete = got.filled() == sizeof(T);
		if (complete)
			memcpy(&out, got.dataPtr(), sizeof(T));
		return complete;
	}

	// Read 'count' elements of type T from 'src' into a freshly allocated buffer.
	//
	// On success, a pointer to the first element is returned and 'position' is advanced by the
	// number of bytes consumed. Returns null (without modifying 'position') if the stream did not
	// contain enough data, or if the requested number of bytes does not fit in a Nat.
	template <class T>
	static T *read(IStream *src, Nat count, Nat &position) {
		// 'count' typically originates from the file. Make sure that the multiplication below can
		// not wrap around, which would give us a buffer that is smaller than what the caller
		// believes it to be.
		if (count > Nat(~Nat(0)) / sizeof(T))
			return null;

		Nat size = Nat(count*sizeof(T));
		Buffer r = src->fill(buffer(src->engine(), size));
		if (r.filled() != size)
			return null;
		position += size;
		return (T *)r.dataPtr();
	}

	// Check if the stream 'from' looks like an ICO file by peeking at its first four bytes.
	static Bool CODECALL icoApplicable(IStream *from) {
		// Signature: reserved (2 bytes, zero), resource type (2 bytes, little endian, 1 for icons)
		static const byte signature[4] = { 0x00, 0x00, 0x01, 0x00 };

		Buffer head = from->peek(storm::buffer(from->engine(), 4));
		if (!head.full())
			return false;

		for (Nat i = 0; i < 4; i++) {
			if (head[i] != signature[i])
				return false;
		}
		return true;
	}

	// Create the options object for the ICO format. Allocated alongside the format object 'f'.
	static FormatOptions *CODECALL icoCreate(ImageFormat *f) {
		ICOOptions *options = new (f) ICOOptions();
		return options;
	}

	// By default, only the best quality version of each resolution is loaded.
	ICOOptions::ICOOptions()
		: bestQualityOnly(true) {
	}

	// Output a textual description of the options, including the loading mode.
	void ICOOptions::toS(StrBuf *to) const {
		const wchar *mode = bestQualityOnly
			? S("load only best quality")
			: S("load all variants");

		*to << S("ICO {");
		*to << mode;
		*to << S("}");
	}

	// Load a single image from an ICO file.
	//
	// Loads the entire set using 'loadSet' and returns the last image in it. 'loadSet' inserts
	// the images in order of increasing size and bit depth.
	//
	// Throws ImageLoadError if the file is malformed, or if it does not contain any images.
	Image *ICOOptions::load(IStream *from) {
		ImageSet *s = loadSet(from);

		// A directory with zero entries yields an empty set. Report that explicitly rather than
		// asking for element 'count() - 1', which would wrap around.
		if (s->count() == 0)
			throw new (this) ImageLoadError(S("The ICO file does not contain any images."));

		return s->at(s->count() - 1);
	}

	/**
	 * ICO directory.
	 *
	 * This is the first structure in the file. It is read straight from the stream into memory
	 * using 'fill', so the member order and sizes must match the file format exactly.
	 */
	struct ICODirectory {
		// Reserved, must be 0.
		nat16 reserved;

		// Resource type, 1 for icons.
		nat16 resType;

		// Number of images. This many ICOEntry structures follow the directory.
		nat16 imageCount;
	};

	/**
	 * A single entry in the directory.
	 *
	 * Read straight from the stream into memory using 'fill', so the member order and sizes must
	 * match the file format exactly.
	 */
	struct ICOEntry {
		// Dimension. Not used when loading: the size is taken from the decoded image instead,
		// since these are not reliable (e.g. both zero for PNG-encoded images).
		byte width;
		byte height;

		// Number of colors.
		byte colorCount;

		// Reserved, must be 0.
		byte reserved;

		// Color planes. Entries with a value larger than 1 are rejected when loading.
		nat16 planes;

		// Bits per pixel. The value in the DIB header is used in favor of this one.
		nat16 bpp;

		// Bytes in the image.
		nat size;

		// Offset in the file, counted from the start of the stream.
		nat offset;
	};

	/**
	 * DIB header in the ICO file.
	 *
	 * Read straight from the stream into memory using 'fill', so the member order and sizes must
	 * match the file format exactly (40 bytes in total).
	 */
	struct DIBHeader {
		// Header size in bytes. May be larger than this header.
		nat size;

		// Image width in pixels.
		nat width;

		// Image height in pixels. In icons, this covers both the color bitmap and the mask that
		// follows it, which is why the decoders divide it by two.
		nat height;

		// Number of planes. Must be 1.
		nat16 planes;

		// Number of bits per pixel. 1, 4, 8, 16, 24, or 32.
		nat16 pixelDepth;

		// Compression type (needs to be 0)
		nat compression;

		// Image size in bytes. Possibly zero for uncompressed images.
		nat imageSize;

		// Resolution in pixels per meter (needs to be 0)
		nat xResolution;
		nat yResolution;

		// Number of color map entries used. The palette decoders treat zero as "the maximum
		// number of entries for the bit depth".
		nat colorsUsed;

		// Number of significant colors (needs to be 0)
		nat colorsImportant;
	};

	/**
	 * An entry in the color table. Located after the DIBHeader in images that use a palette
	 * (1, 4 and 8 bits per pixel).
	 *
	 * Note the order of the channels: blue, green, red, followed by one byte of padding.
	 */
	struct DIBColor {
		byte b;
		byte g;
		byte r;
		byte pad;
	};

	/**
	 * A single loaded entry. We use it to remove duplicates later.
	 *
	 * We can't rely on all data in the ICOEntry structure - mostly due to the presence of
	 * PNG-encoded images where height and width are both zero. So, we load all versions first, and
	 * then discard the ones that were duplicates.
	 *
	 * Note: 'loadedImageType' describes the layout of this structure, and needs to be kept in
	 * sync with any changes made here.
	 */
	struct LoadedImage {
		// Loaded image. Set to null for entries that are discarded as duplicates.
		Image *image;

		// Bits per pixel in the actual image.
		Nat bpp;
	};

	// Ordering of loaded images: by width, then by height, and finally by bits per pixel. Both
	// images must be non-null.
	bool operator <(const LoadedImage &a, const LoadedImage &b) {
		Image *lhs = a.image;
		Image *rhs = b.image;

		if (lhs->width() == rhs->width()) {
			if (lhs->height() == rhs->height())
				return a.bpp < b.bpp;
			return lhs->height() < rhs->height();
		}
		return lhs->width() < rhs->width();
	}

	// Create a LoadedImage from an image and its bit depth.
	static LoadedImage loadedImage(Image *image, Nat bpp) {
		LoadedImage result;
		result.image = image;
		result.bpp = bpp;
		return result;
	}

	// Type description passed to 'runtime::allocArray' when allocating arrays of LoadedImage.
	// It states the element size and lists the offset of the 'image' member, which is the only
	// pointer in the structure. NOTE(review): the '1' is presumably the number of offsets that
	// follow - confirm against the definition of GcType.
	static const GcType loadedImageType = {
		GcType::tArray,
		null,
		null,
		sizeof(LoadedImage),
		1,
		{ OFFSET_OF(LoadedImage, image) }
	};


	static bool compareIcoEntries(const ICOEntry &a, const ICOEntry &b) {
		return a.offset < b.offset;
	}

	// Decoders for the supported bit depths (the number in the name is the number of bits per
	// pixel). Each of them reads the pixel data that follows the DIB header from 'from', updates
	// 'position' with the number of bytes consumed, and returns the decoded image, or null if
	// the pixel data could not be read.
	static Image *icoDecode1(IStream *from, const DIBHeader &header, Nat &position);
	static Image *icoDecode4(IStream *from, const DIBHeader &header, Nat &position);
	static Image *icoDecode8(IStream *from, const DIBHeader &header, Nat &position);
	static Image *icoDecode24(IStream *from, const DIBHeader &header, Nat &position);
	static Image *icoDecode32(IStream *from, const DIBHeader &header, Nat &position);

	// Signature shared by all decoders above.
	typedef Image *(*IcoDecoder)(IStream *from, const DIBHeader &header, Nat &position);

	// Pick a decoder based on the bit depth in 'header'. Returns null if none is available.
	static IcoDecoder pickIcoDecoder(const DIBHeader &header);

	// Load the icon described by 'entry' from the stream 'from'.
	//
	// 'position' is the number of bytes consumed from the stream so far. It is used to skip ahead
	// to the start of the icon, and is updated to reflect the data read by this function. As the
	// stream is only read forwards, entries need to be loaded in order of increasing offset.
	//
	// Returns the decoded image together with its bit depth (PNG images are reported as 32 bits).
	// Throws ImageLoadError if the icon overlaps previously read data, if data is missing, or if
	// the format of the icon is not supported.
	static LoadedImage loadIcon(IStream *from, const ICOEntry &entry, Nat &position) {
		if (position > entry.offset)
			throw new (from) ImageLoadError(S("Malformed ICO file: multiple images overlap each other."));

		// Skip any data between the current position and the start of the icon.
		Nat toSkip = entry.offset - position;
		if (toSkip > 0) {
			from->fill(toSkip);
			position += toSkip;
		}

		// See if this is actually a PNG image.
		if (checkHeader(from, "\x89PNG", false)) {
			// Note: We can't pass a partially seeked stream to the PNG decoder. It needs a random
			// access stream and will seek the stream itself, which will not work. We could adapt
			// the stream, but it is not really worth it since random access is needed anyway.
			Buffer pngData = from->fill(entry.size);
			if (!pngData.full())
				throw new (from) ImageLoadError(S("Not enough data for PNG icon."));
			position += pngData.filled();

			PNGOptions *png = new (from) PNGOptions();
			return loadedImage(png->load(new (from) MemIStream(pngData)), 32);
		}

		// This is (usually) a DIB header. Note that only: size, width, height, planes, bitcount,
		// size are used. Importantly, "compression" is *not* used, so we can't use 'pickDecoder' directly.
		DIBHeader header;
		if (!fill(from, header))
			throw new (from) ImageLoadError(S("Failed to read icon header."));
		position += sizeof(DIBHeader);

		if (header.size < 40 || header.planes != 1)
			throw new (from) ImageLoadError(S("Invalid or incomplete DIB header for icon."));

		// The header in the file may be larger than the part we understand. If so, skip the
		// remainder so that the decoders start reading at the color table or the pixel data,
		// rather than in the middle of the header.
		Nat extraHeader = header.size - Nat(sizeof(DIBHeader));
		if (extraHeader > 0) {
			// The header is a part of the icon's data, so it can not be larger than the icon
			// itself. This also keeps a bogus size from causing a huge allocation below.
			if (header.size > entry.size)
				throw new (from) ImageLoadError(S("Invalid or incomplete DIB header for icon."));

			Buffer skipped = from->fill(extraHeader);
			if (!skipped.full())
				throw new (from) ImageLoadError(S("Failed to read icon header."));
			position += skipped.filled();
		}

		IcoDecoder decode = pickIcoDecoder(header);
		if (!decode)
			throw new (from) ImageLoadError(S("Unsupported bit depth in icon."));

		Image *image = (*decode)(from, header, position);
		if (!image)
			throw new (from) ImageLoadError(S("Failed to load icon bitmap."));

		// This 'pixelDepth' is more trustworthy than the BPP in the ICO header, especially since we
		// now know that it is a bitmap image.
		return loadedImage(image, header.pixelDepth);
	}

	// Load all images in an ICO file into an ImageSet.
	//
	// The stream is read front to back: first the directory, then all directory entries, and
	// finally the images themselves in the order they appear in the file. If 'bestQualityOnly' is
	// set, only the image with the highest bit depth is kept for each resolution.
	//
	// Throws ImageLoadError if the file is malformed.
	ImageSet *ICOOptions::loadSet(IStream *from) {
		// Number of bytes consumed from 'from' so far.
		Nat position = 0;

		// The directory at the start of the file.
		ICODirectory directory;
		if (!fill(from, directory))
			throw new (this) ImageLoadError(S("Invalid or incompatible ICO header."));
		position += sizeof(directory);

		const bool validDirectory = directory.reserved == 0 && directory.resType == 1;
		if (!validDirectory)
			throw new (this) ImageLoadError(S("Invalid or incompatible ICO header."));

		// All entries in the directory.
		std::vector<ICOEntry> entries(directory.imageCount, ICOEntry());
		for (std::vector<ICOEntry>::iterator entry = entries.begin(); entry != entries.end(); ++entry) {
			if (!fill(from, *entry))
				throw new (this) ImageLoadError(S("Invalid ICO directory entry."));
			position += sizeof(ICOEntry);

			if (entry->reserved != 0 || entry->planes > 1)
				throw new (this) ImageLoadError(S("Invalid ICO directory entry."));
		}

		// Order the entries by their location in the file. This lets us read all of them in a
		// single pass over the stream.
		std::sort(entries.begin(), entries.end(), compareIcoEntries);

		// Decode the images in file order.
		const size_t count = entries.size();
		GcArray<LoadedImage> *loaded = runtime::allocArray<LoadedImage>(engine(), &loadedImageType, count);
		for (size_t i = 0; i < count; i++) {
			LoadedImage icon = loadIcon(from, entries[i], position);
			loaded->v[loaded->filled] = icon;
			loaded->filled++;
		}

		// Order by resolution, and by bit depth within each resolution. This way, the best
		// version of each resolution is located last in its group.
		LoadedImage *first = loaded->v;
		LoadedImage *last = loaded->v + loaded->filled;
		std::sort(first, last);

		if (bestQualityOnly && loaded->filled > 0) {
			// Walk backwards from the end. The first image we see of each resolution is the best
			// one. All other images with that resolution are dropped by setting them to null.
			size_t newest = loaded->filled - 1;
			Nat keptWidth = loaded->v[newest].image->width();
			Nat keptHeight = loaded->v[newest].image->height();

			for (size_t at = newest; at-- > 0; ) {
				LoadedImage &candidate = loaded->v[at];
				const bool sameSize = candidate.image->width() == keptWidth
					&& candidate.image->height() == keptHeight;

				if (sameSize) {
					// Same resolution as one we already kept, but worse quality.
					candidate.image = null;
				} else {
					// A new resolution. Keep it.
					keptWidth = candidate.image->width();
					keptHeight = candidate.image->height();
				}
			}
		}

		// Insert what remains into an ImageSet (this is actually cheaper now, since they are
		// already sorted!)
		ImageSet *images = new (this) ImageSet();
		for (LoadedImage *at = loaded->v; at != loaded->v + loaded->filled; ++at) {
			if (at->image)
				images->push(at->image);
		}

		return images;
	}

	// Find the decoder to use for the bit depth specified in 'header'. Returns null if the bit
	// depth is not supported.
	static IcoDecoder pickIcoDecoder(const DIBHeader &header) {
		const Nat depth = header.pixelDepth;

		if (depth == 1)
			return &icoDecode1;
		if (depth == 4)
			return &icoDecode4;
		if (depth == 8)
			return &icoDecode8;

		// Note the oldnewthing blog says 16 bits exist, but not 24 bits. That is likely the
		// opposite - I found 24-bit icons from old Visual Studio stock icons (but no 16-bit),
		// Gimp does not support encoding 16-bit images, and the docs say that the compression
		// field is not used (which would be needed for the bitfield encoding required).
		if (depth == 24)
			return &icoDecode24;
		if (depth == 32)
			return &icoDecode32;

		return null;
	}

	// Read the 1 bit per pixel mask that follows the color data of an icon, and apply it to 'to'.
	//
	// The alpha channel is set to zero for all pixels where the mask is set. Other pixels are left
	// untouched. As for the color data, rows are stored bottom-up and padded to a multiple of 4
	// bytes. 'position' is advanced by the number of bytes read.
	//
	// Returns false if the mask was not complete. Rows read before that point have been applied.
	static bool applyMask(IStream *from, Image *to, Nat w, Nat h, Nat &position) {
		const Nat rowBytes = roundUp((w + 7) / 8, Nat(4));
		Buffer row = buffer(from->engine(), rowBytes);

		for (Nat line = 0; line < h; line++) {
			row.filled(0);
			row = from->read(row);
			position += row.filled();
			if (row.filled() != rowBytes)
				return false;

			// The first row in the file is the bottom row of the image.
			byte *pixels = to->buffer(0, h - line - 1);
			for (Nat col = 0; col < w; col++) {
				// The leftmost pixel is stored in the most significant bit of each byte.
				const byte bit = byte(0x80 >> (col & 0x7));
				if (row[col / 8] & bit) {
					// Masked out: make the pixel transparent.
					pixels[4*col + 3] = 0;
				}
			}
		}

		return true;
	}

	// Decode an icon with 32 bits per pixel.
	//
	// Pixels are stored as blue, green, red, alpha, with rows stored bottom-up. Returns null if
	// the color data is incomplete. The result of applying the mask that follows is ignored, so
	// an incomplete mask does not cause a failure.
	static Image *icoDecode32(IStream *from, const DIBHeader &header, Nat &position) {
		const Nat width = header.width;
		// The height in the header covers both the color bitmap and the 1-bpp mask after it.
		const Nat height = header.height / 2;

		Image *result = new (from) Image(width, height);

		// Color part. No padding is needed, each pixel is 4 bytes.
		const Nat rowBytes = 4 * width;
		Buffer row = buffer(from->engine(), rowBytes);
		for (Nat line = 0; line < height; line++) {
			row.filled(0);
			row = from->read(row);
			position += row.filled();
			if (row.filled() != rowBytes)
				return null;

			byte *out = result->buffer(0, height - line - 1);
			for (Nat col = 0; col < width; col++) {
				const Nat at = 4 * col;
				// Swap red and blue, keep green and alpha.
				out[at + 0] = row[at + 2];
				out[at + 1] = row[at + 1];
				out[at + 2] = row[at + 0];
				out[at + 3] = row[at + 3];
			}
		}

		// Mask part.
		applyMask(from, result, width, height, position);
		return result;
	}

	// Decode an icon with 24 bits per pixel.
	//
	// Pixels are stored as blue, green, red, with rows stored bottom-up and padded to a multiple
	// of 4 bytes. All pixels are made opaque before the mask is applied. Returns null if the
	// color data is incomplete. The result of applying the mask is ignored, so an incomplete
	// mask does not cause a failure.
	static Image *icoDecode24(IStream *from, const DIBHeader &header, Nat &position) {
		const Nat width = header.width;
		// The height in the header covers both the color bitmap and the 1-bpp mask after it.
		const Nat height = header.height / 2;

		Image *result = new (from) Image(width, height);

		// Color part.
		const Nat rowBytes = roundUp(width * 3, Nat(4));
		Buffer row = buffer(from->engine(), rowBytes);
		for (Nat line = 0; line < height; line++) {
			row.filled(0);
			row = from->read(row);
			position += row.filled();
			if (row.filled() != rowBytes)
				return null;

			byte *out = result->buffer(0, height - line - 1);
			for (Nat col = 0; col < width; col++) {
				const Nat in = 3 * col;
				byte *pixel = out + 4 * col;
				pixel[0] = row[in + 2];
				pixel[1] = row[in + 1];
				pixel[2] = row[in + 0];
				pixel[3] = 255;
			}
		}

		// Mask part.
		applyMask(from, result, width, height, position);
		return result;
	}

	// Decode an icon with 8 bits per pixel.
	//
	// The pixel data is preceded by a color table, and each pixel is one byte that is an index
	// into that table. Rows are stored bottom-up and padded to a multiple of 4 bytes.
	//
	// Returns null if the color table is invalid or incomplete, or if the pixel data is
	// incomplete. Pixels that refer to entries outside of the color table become black. The
	// result of applying the mask is ignored, so an incomplete mask does not cause a failure.
	static Image *icoDecode8(IStream *from, const DIBHeader &header, Nat &position) {
		Nat w = header.width;
		Nat h = header.height;

		// These bitmaps are strange. First, we have a color bitmap, then a 1-bpp mask.
		h /= 2;

		// Read color table.
		Nat used = header.colorsUsed;
		if (used == 0)
			used = 256;

		// A pixel can only refer to 256 different entries. A larger table means that the header
		// is corrupt, and we do not want to allocate memory based on a bogus value.
		if (used > 256)
			return null;

		DIBColor *palette = read<DIBColor>(from, used, position);
		if (!palette)
			return null;

		Image *to = new (from) Image(w, h);

		// Image part:
		Nat stride = roundUp(w, Nat(4));
		Buffer src = buffer(from->engine(), stride);
		for (Nat y = 0; y < h; y++) {
			src.filled(0);
			src = from->read(src);
			position += src.filled();
			if (src.filled() != stride)
				return null;

			byte *dest = to->buffer(0, h - y - 1);
			for (Nat x = 0; x < w; x++) {
				Nat color = src[x];

				if (color < used) {
					dest[4*x + 0] = palette[color].r;
					dest[4*x + 1] = palette[color].g;
					dest[4*x + 2] = palette[color].b;
				} else {
					// The table may be smaller than what a pixel can address. Never read
					// outside of it.
					dest[4*x + 0] = 0;
					dest[4*x + 1] = 0;
					dest[4*x + 2] = 0;
				}
				dest[4*x + 3] = 255;
			}
		}

		applyMask(from, to, w, h, position);
		return to;
	}

	// Decode an icon with 4 bits per pixel.
	//
	// The pixel data is preceded by a color table, and each pixel is an index into that table.
	// Two pixels are stored in each byte, with the leftmost pixel in the high bits. Rows are
	// stored bottom-up and padded to a multiple of 4 bytes.
	//
	// Returns null if the color table is invalid or incomplete, or if the pixel data is
	// incomplete. Pixels that refer to entries outside of the color table become black. The
	// result of applying the mask is ignored, so an incomplete mask does not cause a failure.
	static Image *icoDecode4(IStream *from, const DIBHeader &header, Nat &position) {
		Nat w = header.width;
		Nat h = header.height;

		// These bitmaps are strange. First, we have a color bitmap, then a 1-bpp mask.
		h /= 2;

		// Read color table.
		Nat used = header.colorsUsed;
		if (used == 0)
			used = 16;

		// A pixel can only refer to 16 different entries. A larger table means that the header
		// is corrupt, and we do not want to allocate memory based on a bogus value.
		if (used > 16)
			return null;

		DIBColor *palette = read<DIBColor>(from, used, position);
		if (!palette)
			return null;

		Image *to = new (from) Image(w, h);

		// Image part:
		Nat stride = roundUp((w + 1) / 2, Nat(4));
		Buffer src = buffer(from->engine(), stride);
		for (Nat y = 0; y < h; y++) {
			src.filled(0);
			src = from->read(src);
			position += src.filled();
			if (src.filled() != stride)
				return null;

			byte *dest = to->buffer(0, h - y - 1);
			for (Nat x = 0; x < w; x++) {
				// Even pixels are in the high 4 bits, odd pixels in the low 4 bits.
				Nat color = (src[x / 2] >> ((~x & 0x1) * 4)) & 0xF;

				if (color < used) {
					dest[4*x + 0] = palette[color].r;
					dest[4*x + 1] = palette[color].g;
					dest[4*x + 2] = palette[color].b;
				} else {
					// The table may be smaller than what a pixel can address. Never read
					// outside of it.
					dest[4*x + 0] = 0;
					dest[4*x + 1] = 0;
					dest[4*x + 2] = 0;
				}
				dest[4*x + 3] = 255;
			}
		}

		applyMask(from, to, w, h, position);
		return to;
	}

	// Decode an icon with 1 bit per pixel.
	//
	// The pixel data is preceded by a color table, and each pixel is an index into that table.
	// Eight pixels are stored in each byte, with the leftmost pixel in the most significant bit.
	// Rows are stored bottom-up and padded to a multiple of 4 bytes.
	//
	// Returns null if the color table is invalid or incomplete, or if the pixel data is
	// incomplete. Pixels that refer to entries outside of the color table become black. The
	// result of applying the mask is ignored, so an incomplete mask does not cause a failure.
	static Image *icoDecode1(IStream *from, const DIBHeader &header, Nat &position) {
		Nat w = header.width;
		Nat h = header.height;

		// These bitmaps are strange. First, we have a color bitmap, then a 1-bpp mask.
		h /= 2;

		// Read color table.
		Nat used = header.colorsUsed;
		if (used == 0)
			used = 2;

		// A pixel can only refer to 2 different entries. A larger table means that the header
		// is corrupt, and we do not want to allocate memory based on a bogus value.
		if (used > 2)
			return null;

		DIBColor *palette = read<DIBColor>(from, used, position);
		if (!palette)
			return null;

		Image *to = new (from) Image(w, h);

		// Image part:
		Nat stride = roundUp((w + 7) / 8, Nat(4));
		Buffer src = buffer(from->engine(), stride);
		for (Nat y = 0; y < h; y++) {
			src.filled(0);
			src = from->read(src);
			position += src.filled();
			if (src.filled() != stride)
				return null;

			byte *dest = to->buffer(0, h - y - 1);
			for (Nat x = 0; x < w; x++) {
				Nat color = (src[x / 8] >> (7 - (x & 0x7))) & 0x1;

				if (color < used) {
					dest[4*x + 0] = palette[color].r;
					dest[4*x + 1] = palette[color].g;
					dest[4*x + 2] = palette[color].b;
				} else {
					// The table may be smaller than what a pixel can address. Never read
					// outside of it.
					dest[4*x + 0] = 0;
					dest[4*x + 1] = 0;
					dest[4*x + 2] = 0;
				}
				dest[4*x + 3] = 255;
			}
		}

		applyMask(from, to, w, h, position);
		return to;
	}


	// Saving is not implemented for this format. Always throws ImageSaveError.
	void ICOOptions::save(Image *image, OStream *to) {
		ImageSaveError *unsupported = new (this) ImageSaveError(S("Can not save ICO files yet."));
		throw unsupported;
	}


	// Create the description of the ICO format: its name, the file extensions associated with
	// it, and the functions used to detect the format and to create its options.
	ImageFormat *icoFormat(Engine &e) {
		// Null-terminated list of file extensions.
		const wchar *extensions[] = { S("ico"), null };

		ImageFormat *format = new (e) ImageFormat(S("Icon"), extensions, &icoApplicable, &icoCreate);
		return format;
	}

}
