// SPDX-License-Identifier: BSD-3-Clause
/*
 * Copyright (c) 2024, MediaTek Inc. All rights reserved.
 *
 * Author: Weijie Gao <weijie.gao@mediatek.com>
 */

#include <assert.h>
#include <errno.h>
#include <string.h>
#include <arch_helpers.h>
#include <common/debug.h>
#include <common/bl_common.h>
#include <common/tbbr/tbbr_img_def.h>
#include <common/tf_crc32.h>
#include <drivers/io/io_storage.h>
#include <drivers/mmc.h>
#include <plat/common/platform.h>
#include <platform_def.h>
#include <tools_share/firmware_image_package.h>
#include "bl2_plat_setup.h"
#include "dual_fip.h"

/*
 * Dual-FIP result in DRAM: it occupies the DUAL_FIP_RESULT_LEN bytes
 * immediately below BL33_BASE, where flush_dual_fip_result() copies dfipr.
 */
#define DUAL_FIP_RESULT_ADDR		(BL33_BASE - DUAL_FIP_RESULT_LEN)

/*
 * Dual-FIP state gathered during this boot: fip_state[0] / fip_state[1]
 * describe FIP1 / FIP2, and check_dual_fip() fills in ver and crc.
 */
static struct dual_fip_result dfipr;

/*
 * Compare two UUIDs byte by byte.
 * Returns 0 when they are identical, non-zero otherwise (memcmp() semantics).
 */
static inline int compare_uuids(const uuid_t *uuid1, const uuid_t *uuid2)
{
	const size_t uuid_len = sizeof(*uuid1);

	return memcmp(uuid1, uuid2, uuid_len);
}

/* Clear the first DUAL_FIP_RESULT_LEN bytes of the Dual-FIP result record. */
void dual_fip_init(void)
{
	void *result = &dfipr;

	(void)memset(result, 0, DUAL_FIP_RESULT_LEN);
}

/*
 * Mark the device backing one FIP copy as unavailable.
 *
 * @fip2: true selects FIP2 (fip_state[1]), false selects FIP1 (fip_state[0]).
 */
void set_fip_dev_invalid(bool fip2)
{
	unsigned int idx = fip2 ? 1U : 0U;

	dfipr.fip_state[idx] = DUAL_FIP_STATE_NODEV;
}

/*
 * Adapted from bl2/bl2_image_load_v2.c.
 *
 * Image loading is used as the integrity check for the currently selected
 * FIP. Every node of the platform's BL2 load list goes through the
 * pre-load hook, then load_auth_image() (unless IMAGE_ATTRIB_SKIP_LOADING is
 * set), then the post-load hook. Images are really loaded to their
 * destinations as a side effect.
 *
 * Returns true if every step succeeds, false on the first failure.
 */
static bool load_fip_images(void)
{
	bl_load_info_t *load_info = plat_get_bl_image_load_info();
	const bl_load_info_node_t *node;
	int rc;

	assert(load_info != NULL);
	assert(load_info->head != NULL);
	assert(load_info->h.type == PARAM_BL_LOAD_INFO);
	assert(load_info->h.version >= VERSION_2);

	for (node = load_info->head; node != NULL; node = node->next_load_info) {
		rc = bl2_plat_handle_pre_image_load(node->image_id);
		if (rc != 0) {
			ERROR("BL2: Failure in pre image load handling (%i)\n",
			      rc);
			return false;
		}

		/* Nodes flagged skip-loading only go through the hooks */
		if ((node->image_info->h.attr &
		     IMAGE_ATTRIB_SKIP_LOADING) == 0U) {
			INFO("BL2: Loading image id %u for validation\n",
			     node->image_id);

			rc = load_auth_image(node->image_id, node->image_info);
			if (rc != 0) {
				ERROR("BL2: Failed to load image id %u (%i) for validation\n",
				      node->image_id, rc);
				return false;
			}
		}

		/* Allow platform to handle image information. */
		rc = bl2_plat_handle_post_image_load(node->image_id);
		if (rc != 0) {
			ERROR("BL2: Failure in post image load handling (%i)\n",
			      rc);
			return false;
		}
	}

	return true;
}

/*
 * Read one complete FIP image from @image_handle into memory at @addr.
 *
 * The ToC is parsed first. The FIP data size is the largest end offset
 * (offset_address + size) of any ToC entry, and never less than the ToC
 * itself. The handle is then rewound and that many bytes are copied to @addr.
 *
 * @image_handle: open handle, assumed to be positioned at the FIP start
 *                (freshly opened)
 * @addr:         destination buffer; must be able to hold @maxsize bytes
 * @maxsize:      upper bound for both the ToC and the FIP data size
 * @retsize:      receives the number of bytes copied on success
 *
 * Returns 0 on success, or one of:
 *   - the io_* error code on I/O failure,
 *   - -EINVAL for a short or malformed header, or an overflowing entry,
 *   - -E2BIG when the ToC or data exceeds @maxsize,
 *   - -EIO on a truncated entry or when the driver makes no progress.
 */
static int read_fip(uintptr_t image_handle, uintptr_t addr, size_t maxsize,
		    size_t *retsize)
{
	static const uuid_t uuid_null = { { 0 } }; /* Double braces for clang */
	size_t bytes_read, len, toc_size, fip_size = 0;
	fip_toc_header_t header;
	fip_toc_entry_t entry;
	uint64_t end;
	int ret;

	/* Check header */
	ret = io_read(image_handle, (uintptr_t)&header, sizeof(header),
		      &bytes_read);
	if (ret) {
		ERROR("Failed to read FIP header (%d)\n", ret);
		return ret;
	}

	/* Test bytes_read first so a short read never inspects stale fields */
	if (bytes_read != sizeof(header) ||
	    header.name != TOC_HEADER_NAME || !header.serial_number) {
		ERROR("Not a valid FIP header\n");
		return -EINVAL;
	}

	toc_size = sizeof(header);

	/* Check entries up to the null-UUID terminator */
	for (;;) {
		/*
		 * Bound the ToC by maxsize so that a missing terminator cannot
		 * make this loop run past the end of the region.
		 */
		toc_size += sizeof(entry);
		if (toc_size > maxsize) {
			ERROR("FIP ToC is too large\n");
			return -E2BIG;
		}

		ret = io_read(image_handle, (uintptr_t)&entry, sizeof(entry),
			      &bytes_read);
		if (ret) {
			ERROR("Failed to read FIP entry (%d)\n", ret);
			return ret;
		}

		/* A short read would leave (part of) the previous entry behind */
		if (bytes_read != sizeof(entry)) {
			ERROR("Truncated FIP entry\n");
			return -EIO;
		}

		if (!compare_uuids(&entry.uuid, &uuid_null))
			break;

		/* Use 64-bit math: size_t may be narrower than the ToC fields */
		end = entry.offset_address + entry.size;
		if (end < entry.offset_address) {
			ERROR("Entry offset and size overflow\n");
			return -EINVAL;
		}

		if (end > maxsize) {
			ERROR("FIP size is too large\n");
			return -E2BIG;
		}

		if (end > fip_size)
			fip_size = (size_t)end;
	}

	/* The ToC itself is part of the FIP image */
	if (fip_size < toc_size)
		fip_size = toc_size;

	NOTICE("FIP data size: %zu\n", fip_size);

	/* Now read FIP data */
	ret = io_seek(image_handle, IO_SEEK_SET, 0);
	if (ret) {
		ERROR("Unable to seek to start of FIP\n");
		return ret;
	}

	for (len = 0; len < fip_size; len += bytes_read) {
		ret = io_read(image_handle, addr + len, fip_size - len,
			      &bytes_read);
		if (ret) {
			ERROR("Failed to read FIP data (%d)\n", ret);
			return ret;
		}

		/* Success with zero progress would otherwise loop forever */
		if (!bytes_read) {
			ERROR("Unexpected end of FIP data\n");
			return -EIO;
		}
	}

	*retsize = fip_size;

	return 0;
}

/*
 * Write exactly @size bytes from memory at @data to @image_handle, looping
 * over partial writes.
 *
 * Returns 0 on success, or the io_write() error code.
 * Returns -EIO if the driver reports success without writing anything,
 * which would otherwise loop forever.
 */
static int write_data(uintptr_t image_handle, uintptr_t data, size_t size)
{
	size_t bytes_written, len;
	int ret;

	for (len = 0; len < size; len += bytes_written) {
		ret = io_write(image_handle, data + len, size - len,
			       &bytes_written);
		if (ret) {
			ERROR("Failed to write FIP data (%d)\n", ret);
			return ret;
		}

		if (!bytes_written) {
			ERROR("Failed to write FIP data (no progress)\n");
			return -EIO;
		}
	}

	return 0;
}

/*
 * Read exactly @size bytes from @image_handle into memory at @data, looping
 * over partial reads.
 *
 * Returns 0 on success, or the io_read() error code.
 * Returns -EIO if the driver reports success without reading anything
 * (e.g. end of region), which would otherwise loop forever.
 */
static int read_data(uintptr_t image_handle, uintptr_t data, size_t size)
{
	size_t bytes_read, len;
	int ret;

	for (len = 0; len < size; len += bytes_read) {
		ret = io_read(image_handle, data + len, size - len,
			      &bytes_read);
		if (ret) {
			ERROR("Failed to read FIP data (%d)\n", ret);
			return ret;
		}

		if (!bytes_read) {
			ERROR("Failed to read FIP data (no progress)\n");
			return -EIO;
		}
	}

	return 0;
}

static bool resotre_fip(bool to_fip2)
{
	uintptr_t fip_src_dev_handle, fip_dst_dev_handle;
	uintptr_t fip_src_image_spec, fip_dst_image_spec;
	uintptr_t fip_src_image_handle, fip_dst_image_handle;
	io_block_spec_t *bspec;
	const uint8_t *p1, *p2;
	size_t i, size, wrsz;
	int ret;

	if (to_fip2) {
		ret = plat_get_image_source(FIP_IMAGE_ID, &fip_src_dev_handle,
					    &fip_src_image_spec);
		if (ret) {
			ERROR("Failed to obtain reference to FIP1 image as source (%d)\n",
			       ret);
			return false;
		}

		ret = plat_get_image_source(FIP2_IMAGE_ID, &fip_dst_dev_handle,
					    &fip_dst_image_spec);
		if (ret) {
			ERROR("Failed to obtain reference to FIP2 image as destination (%d)\n",
			       ret);
			return false;
		}
	} else {
		ret = plat_get_image_source(FIP_IMAGE_ID, &fip_dst_dev_handle,
					    &fip_dst_image_spec);
		if (ret) {
			ERROR("Failed to obtain reference to FIP1 image as destination (%d)\n",
			       ret);
			return false;
		}

		ret = plat_get_image_source(FIP2_IMAGE_ID, &fip_src_dev_handle,
					    &fip_src_image_spec);
		if (ret) {
			ERROR("Failed to obtain reference to FIP2 image as source (%d)\n",
			       ret);
			return false;
		}
	}

	ret = io_open(fip_src_dev_handle, fip_src_image_spec, &fip_src_image_handle);
	if (ret) {
		ERROR("Failed to access source FIP image (%d)\n", ret);
		return false;
	}

	bspec = (io_block_spec_t *)fip_src_image_spec;

	/* Read to BL33 base */
	ret = read_fip(fip_src_image_handle, BL33_BASE, bspec->length, &size);
	if (ret) {
		ERROR("Failed to read source FIP image data (%d)\n", ret);
		io_close(fip_src_image_handle);
		return false;
	}

	NOTICE("Successfully read source FIP image data\n");

	io_close(fip_src_image_handle);

	ret = io_open(fip_dst_dev_handle, fip_dst_image_spec, &fip_dst_image_handle);
	if (ret) {
		ERROR("Failed to access destination FIP image (%d)\n", ret);
		return false;
	}

	/* Align data size */
	wrsz = ((size + MMC_BLOCK_SIZE * 2 - 1) & ~(MMC_BLOCK_SIZE - 1));

	ret = write_data(fip_dst_image_handle, BL33_BASE, wrsz);
	if (ret) {
		ERROR("Failed to write destination FIP image data (%d)\n", ret);
		io_close(fip_dst_image_handle);
		return false;
	}

	NOTICE("Successfully written destination FIP image data\n");

	ret = io_seek(fip_dst_image_handle, IO_SEEK_SET, 0);
	if (ret) {
		ERROR("Unable to seek to start of destination FIP\n");
		io_close(fip_dst_image_handle);
		return false;
	}

	ret = read_data(fip_dst_image_handle, BL33_BASE + size, size);
	if (ret) {
		ERROR("Failed to read destination FIP image data for verification (%d)\n", ret);
		io_close(fip_dst_image_handle);
		return false;
	}

	io_close(fip_dst_image_handle);

	/* Verify data */
	p1 = (const uint8_t *)BL33_BASE;
	p2 = (const uint8_t *)(BL33_BASE + size);

	for (i = 0; i < size; i++) {
		if (p1[i] != p2[i]) {
			ERROR("FIP data verification failed at 0x%zx, expected 0x%02x, got 0x%02x\n",
			      i, p1[i], p2[i]);
			return false;
		}
	}

	NOTICE("Destination FIP image data has been verified\n");

	return true;
}

/*
 * Check one FIP copy by loading its images, and record the result in dfipr.
 *
 * @fip2: false selects FIP1 (fip_state[0]), true selects FIP2 (fip_state[1]).
 *
 * A copy already marked DUAL_FIP_STATE_NODEV is reported and left untouched.
 */
static void check_one_fip(bool fip2)
{
	unsigned int idx = fip2 ? 1U : 0U;
	unsigned int num = idx + 1U;

	if (dfipr.fip_state[idx] == DUAL_FIP_STATE_NODEV) {
		WARN("FIP%u is unavailable\n", num);
		return;
	}

	NOTICE("Checking integrity of FIP%u ...\n", num);
	mtk_bl2_set_fip_image_id(fip2);

	if (!load_fip_images()) {
		WARN("Bad integrity of FIP%u\n", num);
		dfipr.fip_state[idx] = DUAL_FIP_STATE_WAS_BAD;
	} else {
		dfipr.fip_state[idx] = DUAL_FIP_STATE_OK;
		NOTICE("FIP%u is good\n", num);
	}
}

/*
 * Rewrite one FIP copy from the other when the target was found bad and the
 * other copy is good.
 *
 * @to_fip2: false targets FIP1, true targets FIP2.
 */
static void restore_bad_fip(bool to_fip2)
{
	unsigned int dst = to_fip2 ? 1U : 0U;
	unsigned int src = to_fip2 ? 0U : 1U;

	if (dfipr.fip_state[dst] != DUAL_FIP_STATE_WAS_BAD ||
	    dfipr.fip_state[src] != DUAL_FIP_STATE_OK)
		return;

	NOTICE("Restoring FIP%u ...\n", dst + 1U);
	if (resotre_fip(to_fip2))
		NOTICE("FIP%u has been updated\n", dst + 1U);
}

/*
 * Check both FIP copies, fill in the Dual-FIP result (ver + crc), try to
 * restore a bad copy from the good one, and select the FIP to boot from.
 *
 * FIP1 is preferred whenever it is good. This function panics when neither
 * copy is good; otherwise it returns true.
 */
bool check_dual_fip(void)
{
	NOTICE("Starting Dual-FIP checking ...\n");

	check_one_fip(false);
	check_one_fip(true);

	if (dfipr.fip_state[0] != DUAL_FIP_STATE_OK &&
	    dfipr.fip_state[1] != DUAL_FIP_STATE_OK) {
		ERROR("Both FIPs are bad. Unable to boot.\n");
		panic();
	}

	/*
	 * Assemble FIP result. Clear crc first so that, if it lies within the
	 * checksummed range, it never contributes a stale value.
	 */
	dfipr.ver = DUAL_FIP_VER;
	dfipr.crc = 0;
	dfipr.crc = tf_crc32(0, (void *)&dfipr, DUAL_FIP_RESULT_LEN);

	/* Restore broken FIP if possible */
	restore_bad_fip(false);
	restore_bad_fip(true);

	/* Select the boot FIP: FIP1 if it is good, otherwise FIP2 */
	if (dfipr.fip_state[0] != DUAL_FIP_STATE_OK) {
		NOTICE("Using FIP2 for booting\n");
		mtk_bl2_set_fip_image_id(true);
	} else {
		NOTICE("Using FIP1 for booting\n");
		mtk_bl2_set_fip_image_id(false);
	}

	return true;
}

/*
 * Copy the Dual-FIP result record to DUAL_FIP_RESULT_ADDR, then flush the
 * data-cache lines covering that range with flush_dcache_range().
 */
void flush_dual_fip_result(void)
{
	uintptr_t dst = (uintptr_t)DUAL_FIP_RESULT_ADDR;

	memcpy((void *)dst, &dfipr, DUAL_FIP_RESULT_LEN);
	flush_dcache_range(dst, DUAL_FIP_RESULT_LEN);
}
