/*
 * Copyright (C) 2023, MediaTek Inc. All rights reserved.
 *
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Generic secure firmware download
 */

#include <errno.h>
#include <stdbool.h>
#include <common/bl_common.h>
#include <common/tf_crc32.h>
#include <lib/xlat_tables/xlat_tables_v2.h>
#include <drivers/auth/mbedtls/mbedtls_common.h>
#include <mbedtls/memory_buffer_alloc.h>
#include <mbedtls/sha256.h>
#include <mbedtls/rsa.h>
#include <platform_def.h>
#include "fwdl-internal.h"

/*
 * Address window that a caller-supplied firmware buffer must fall within
 * (checked by fw_load()).
 */
#define DRAM_BASE			0x40000000ULL
#define DRAM_MAX_SIZE			0x200000000ULL
/* Largest firmware image size, in bytes, accepted by fw_load(). */
#define MAX_FW_SIZE			0x200000

/*
 * Backing store handed to the mbedtls buffer allocator; it is initialised
 * and released on every fw_signature_verify_rsa() call.
 */
static uint8_t mbedtls_heap[TF_MBEDTLS_HEAP_SIZE];

/*
 * Firmware image types this service can load. find_fw_reg() selects the
 * entry whose magic, platform id and role match the image header.
 */
static struct fw_image_register *fwimg_regs[] = {
	&tops_fwimg_reg,
	&wo_fwimg_reg,
};

/*
 * Verify the RSA / SHA-256 / PSS signature appended to a firmware image.
 *
 * Layout after the validated image (offsets relative to @fwdata):
 *   total_size                : struct fw_rsa_pubkey (embedded public key)
 *   sb_len                    : RSA_MD_HASH_LEN-byte SHA-256 of [0, sb_len)
 *   sb_len + RSA_MD_HASH_LEN  : RSA_SIG_LEN-byte PSS signature
 * where sb_len = total_size + sizeof(struct fw_rsa_pubkey) and must equal
 * the header's sign_body_len.
 *
 * @fwdata:     mapped firmware image
 * @fwsize:     number of bytes available at @fwdata
 * @total_size: image size (headers + payload) as computed by fw_validate()
 * @keyhash:    optional 32-byte SHA-256 of the expected public key; when NULL
 *              the embedded key is used without being pinned
 *
 * Return: 0 if the signature verifies, -EBADMSG for malformed or mismatching
 * input, otherwise the nonzero error code returned by mbedtls.
 */
static int fw_signature_verify_rsa(const void *fwdata, size_t fwsize,
				   size_t total_size, const uint8_t *keyhash)
{
	const struct fw_rsa_pubkey *rsapk = fwdata + total_size;
	const struct fw_base_header *fwhdr = fwdata;
	const void *sig_hash, *sig_data;
	mbedtls_rsa_context rctx;
	mbedtls_mpi n, e;
	uint8_t hash[32];
	uint32_t sb_len;
	int ret;

	sb_len = total_size + sizeof(struct fw_rsa_pubkey);
	if (sb_len != fwhdr->sign_body_len) {
		FWDL_E("FWDL: Firmware RSA message body length mismatch\n");
		return -EBADMSG;
	}

	if (sb_len + RSA_MD_HASH_LEN + RSA_SIG_LEN > fwsize) {
		FWDL_E("FWDL: Incomplete firmware RSA signature\n");
		return -EBADMSG;
	}

	mbedtls_memory_buffer_alloc_init(mbedtls_heap, sizeof(mbedtls_heap));

	/*
	 * Initialise every mbedtls object up front so the single cleanup path
	 * can free all of them wherever we bail out. mbedtls_rsa_import()
	 * copies N and E, so the local MPIs must be freed on success too.
	 */
	mbedtls_mpi_init(&n);
	mbedtls_mpi_init(&e);
	mbedtls_rsa_init(&rctx, MBEDTLS_RSA_PKCS_V21, MBEDTLS_MD_SHA256);

	if (keyhash) {
		ret = mbedtls_sha256_ret((const unsigned char *)rsapk,
					 sizeof(*rsapk), hash, 0);
		if (ret) {
			FWDL_E("FWDL: Failed to calculate SHA256 of RSA pubkey\n");
			goto cleanup;
		}

		if (memcmp(keyhash, hash, sizeof(hash))) {
			FWDL_E("FWDL: RSA pubkey hash mismatch\n");
			ret = -EBADMSG;
			goto cleanup;
		}
	}

	sig_hash = fwdata + sb_len;
	sig_data = sig_hash + RSA_MD_HASH_LEN;

	ret = mbedtls_sha256_ret(fwdata, sb_len, hash, 0);
	if (ret) {
		FWDL_E("FWDL: Failed to calculate SHA256 of signature body\n");
		goto cleanup;
	}

	if (memcmp(sig_hash, hash, sizeof(hash))) {
		FWDL_E("FWDL: Signature body hash mismatch\n");
		ret = -EBADMSG;
		goto cleanup;
	}

	/*
	 * The length fields come from the image itself (and are unpinned
	 * when keyhash is NULL); never let them index past the key arrays.
	 */
	if (rsapk->modulus_len > sizeof(rsapk->modulus) ||
	    rsapk->exponent_len > sizeof(rsapk->exponent)) {
		FWDL_E("FWDL: Invalid RSA public key length\n");
		ret = -EBADMSG;
		goto cleanup;
	}

	ret = mbedtls_mpi_read_binary(&n, rsapk->modulus, rsapk->modulus_len);
	if (ret) {
		FWDL_E("FWDL: Failed to import RSA key modulus\n");
		goto cleanup;
	}

	ret = mbedtls_mpi_read_binary(&e, rsapk->exponent, rsapk->exponent_len);
	if (ret) {
		FWDL_E("FWDL: Failed to import RSA key exponent\n");
		goto cleanup;
	}

	ret = mbedtls_rsa_import(&rctx, &n, NULL, NULL, NULL, &e);
	if (ret) {
		FWDL_E("FWDL: Failed to import RSA public key\n");
		goto cleanup;
	}

	ret = mbedtls_rsa_complete(&rctx);
	if (ret) {
		FWDL_E("FWDL: RSA public key is unusable\n");
		goto cleanup;
	}

	/*
	 * PSS verification reads as many signature bytes as the key is long;
	 * only RSA_SIG_LEN bytes were bounds-checked against fwsize above.
	 */
	if (mbedtls_rsa_get_len(&rctx) > RSA_SIG_LEN) {
		FWDL_E("FWDL: RSA key is larger than the signature\n");
		ret = -EBADMSG;
		goto cleanup;
	}

	/*
	 * Verify against the locally computed digest rather than re-reading
	 * sig_hash from the mapped image, which could change after the
	 * memcmp() above.
	 */
	ret = mbedtls_rsa_rsassa_pss_verify(&rctx, NULL, NULL,
					    MBEDTLS_RSA_PUBLIC,
					    MBEDTLS_MD_SHA256, 0,
					    hash, sig_data);
	if (ret)
		FWDL_E("FWDL: RSA signature verification failed\n");

cleanup:
	mbedtls_rsa_free(&rctx);
	mbedtls_mpi_free(&e);
	mbedtls_mpi_free(&n);
	mbedtls_memory_buffer_alloc_free();
	memset(hash, 0, sizeof(hash));

	return ret;
}

/*
 * Check the image signature using the scheme named in its header.
 *
 * @fwdata:     mapped firmware image
 * @fwsize:     number of bytes available at @fwdata
 * @total_size: validated image size from fw_validate()
 * @keyhash:    optional expected public-key hash, forwarded to the verifier
 *
 * Return: 0 when the image is accepted, -ENOTSUP for a signing type that is
 * unknown or disallowed in this build, otherwise the verifier's error code.
 */
static int fw_signature_verify(const void *fwdata, size_t fwsize,
			       size_t total_size, const uint8_t *keyhash)
{
	const struct fw_base_header *hdr = fwdata;
	int err;

	switch (hdr->signing_type) {
	case FW_SIG_RSA2048_SHA256_PSS:
		err = fw_signature_verify_rsa(fwdata, fwsize, total_size,
					      keyhash);
		if (err == 0)
			FWDL_N("FWDL: Firmware signature verification passed\n");
		return err;

	case FW_SIG_NONE:
#ifdef FWDL_ALLOW_NO_SIGNING
		/* XXX: disable this type for release */
		FWDL_N("FWDL: Firmware is not signed\n");
		return 0;
#else
		FWDL_E("FWDL: Firmware without signing is not allowed\n");
		return -ENOTSUP;
#endif

	default:
		FWDL_E("FWDL: Unsupported firmware signing type\n");
		return -ENOTSUP;
	}
}

static enum fwdl_status fw_validate(const void *fwdata, size_t fwsize,
				    size_t *ret_total_size)
{
	const struct fw_base_header *fwhdr = fwdata;
	const struct fw_part_header *phdr;
	uint32_t crc, ph_len, total_size;
	struct fw_base_header crchdr;
	const uint8_t *pldata;

	if (fwhdr->hdr_len != sizeof(struct fw_base_header)) {
		FWDL_E("FWDL: Firmware header length mismatch\n");
		return FWDL_STATUS_INVALID_FW_FORMAT;
	}

	memcpy(&crchdr, fwhdr, sizeof(struct fw_base_header));

	crchdr.hdr_crc = 0;
	crc = tf_crc32(0, (uint8_t *)&crchdr, sizeof(crchdr));

	if (crc != fwhdr->hdr_crc) {
		FWDL_E("FWDL: Firmware header checksum mismatch\n");
		return FWDL_STATUS_FW_VALIDATION_FAILED;
	}

	if (fwhdr->part_header_len != sizeof(struct fw_part_header)) {
		FWDL_E("FWDL: Firmware part header length mismatch\n");
		return FWDL_STATUS_INVALID_FW_FORMAT;
	}

	if (fwhdr->hdr_ver > FW_HDR_VER) {
		FWDL_E("FWDL: Unsupported firmware header version\n");
		return FWDL_STATUS_UNSUPPORTED_FW_HDR_VER;
	}

	if (fwhdr->signing_type >= __FW_SIG_MAX) {
		FWDL_E("FWDL: Unsupported firmware signing type\n");
		return FWDL_STATUS_FW_VALIDATION_FAILED;
	}

	ph_len = fwhdr->num_parts * fwhdr->part_header_len;
	total_size = fwhdr->hdr_len + ph_len + fwhdr->payload_len;

	if (total_size > fwsize) {
		FWDL_E("FWDL: Firmware is incomplete\n");
		return FWDL_STATUS_INVALID_SIZE;
	}

	phdr = fwdata + fwhdr->hdr_len;
	crc = tf_crc32(0, (const uint8_t *)phdr, ph_len);

	if (crc != fwhdr->part_hdr_crc) {
		FWDL_E("FWDL: Firmware part header checksum mismatch\n");
		return FWDL_STATUS_FW_VALIDATION_FAILED;
	}

	pldata = fwdata + fwhdr->hdr_len + ph_len;
	crc = tf_crc32(0, pldata, fwhdr->payload_len);

	if (crc != fwhdr->payload_crc) {
		FWDL_E("FWDL: Firmware payload checksum mismatch\n");
		return FWDL_STATUS_FW_VALIDATION_FAILED;
	}

	*ret_total_size = total_size;

	return FWDL_STATUS_OK;
}

static enum fwdl_status find_fw_reg(const void *fwdata,
				    struct fw_image_register **retfwir)
{
	const struct fw_base_header *fwhdr = fwdata;
	bool fw_magic_match = false, fw_plat_match = false, curr_fw_plat_match;
	struct fw_image_register *fwir = NULL;
	uint32_t i, j;

	for (i = 0; i < ARRAY_SIZE(fwimg_regs); i++) {
		if (fwhdr->magic != fwimg_regs[i]->magic)
			continue;

		fw_magic_match = true;
		curr_fw_plat_match = false;

		for (j = 0; j < fwimg_regs[i]->num_plats; j++) {
			if (fwhdr->plat_id == fwimg_regs[i]->plat_ids[j]) {
				fw_plat_match = true;
				curr_fw_plat_match = true;
				break;
			}
		}

		if (!curr_fw_plat_match)
			continue;

		for (j = 0; j < fwimg_regs[i]->num_roles; j++) {
			if (fwhdr->role == fwimg_regs[i]->role_ids[j]) {
				fwir = fwimg_regs[i];
				break;
			}
		}

		if (fwir)
			break;
	}

	if (!fw_magic_match) {
		FWDL_E("FWDL: Unsupported firmware magic\n");
		return FWDL_STATUS_INVALID_FW_FORMAT;
	}

	if (!fw_plat_match) {
		FWDL_E("FWDL: Unsupported firmware platform\n");
		return FWDL_STATUS_INVALID_FW_PLATFORM;
	}

	if (!fwir) {
		FWDL_E("FWDL: Unsupported firmware role\n");
		return FWDL_STATUS_INVALID_FW_ROLE;
	}

	*retfwir = fwir;

	return FWDL_STATUS_OK;
}

/*
 * Return the slot index in fwir->required_part_types[] whose entry equals
 * @part_type, or -1 if this image type has no slot for that part.
 */
static int fw_find_part_reg(struct fw_image_register *fwir, uint8_t part_type)
{
	uint32_t idx = 0;

	while (idx < fwir->num_parts) {
		if (fwir->required_part_types[idx] == part_type)
			return (int)idx;
		idx++;
	}

	return -1;
}

static enum fwdl_status fw_load_real(const void *fwdata, uint64_t flags,
				     struct fw_image_register *fwir)
{
	const struct fw_base_header *fwhdr = fwdata;
	const struct fw_part_header *phdr = fwdata + fwhdr->hdr_len;
	uint32_t i, j, part_off;
	enum fwdl_status status;
	int pr_idx;

	spin_lock(&fwir->load_lock);

	memset(fwir->part_data, 0, fwir->num_parts * sizeof(*fwir->part_data));

	part_off = fwhdr->hdr_len + fwhdr->num_parts * fwhdr->part_header_len;
	FWDL_N("FWDL: Part data start at 0x%x\n", part_off);

	for (i = 0; i < fwhdr->num_parts; i++) {
		FWDL_N("FWDL: Part %u data at 0x%x, size 0x%x\n",
		       phdr[i].type, part_off, phdr[i].data_size);

		pr_idx = fw_find_part_reg(fwir, phdr[i].type);
		if (pr_idx < 0) {
			FWDL_N("FWDL: Part %u is not registered to use\n",
			       phdr[i].type);
			goto next_part;
		}

		fwir->part_data[pr_idx].data = fwdata + part_off;
		fwir->part_data[pr_idx].size = phdr[i].data_size;
		fwir->part_data[pr_idx].flags = phdr[i].flags;

		for (j = 0; j < ARRAY_SIZE(phdr[i].value); j++)
			fwir->part_data[pr_idx].value[j] = phdr[i].value[j];

	next_part:
		part_off += ((phdr[i].data_size + PAYLOAD_ALIGNMENT - 1)
			& ~(PAYLOAD_ALIGNMENT - 1));
	}

	status = fwir->do_fw_load(fwir, fwhdr->plat_id, fwhdr->role, flags);

	spin_unlock(&fwir->load_lock);

	return status;
}

/*
 * Load a firmware image that the caller placed in DRAM.
 *
 * @flags: forwarded to the matching registration's do_fw_load()
 * @addr:  physical address of the image; must be page aligned and the whole
 *         range [addr, addr + size) must lie in
 *         [DRAM_BASE, DRAM_BASE + DRAM_MAX_SIZE)
 * @size:  image size in bytes, from sizeof(struct fw_base_header) up to
 *         MAX_FW_SIZE
 *
 * The image is temporarily mapped with the non-secure attribute, then:
 * structurally validated, matched to a registration, signature-checked and
 * passed to fw_load_real(). The mapping is removed before returning.
 *
 * Return: an enum fwdl_status value: the status of the first failing step,
 * or the result of fw_load_real().
 */
uint64_t fw_load(uint64_t flags, uint64_t addr, uint64_t size)
{
	const uint64_t dram_end = DRAM_BASE + DRAM_MAX_SIZE;
	enum fwdl_status status = FWDL_STATUS_INVALID_PARAM;
	size_t fw_size_align, total_size;
	struct fw_image_register *fwir;
	uintptr_t fw_base;
	int ret;

	/*
	 * Compare size against the room left above addr instead of computing
	 * addr + size: a caller-supplied addr near UINT64_MAX would wrap the
	 * sum and slip past the upper bound. A buffer ending exactly at
	 * dram_end is in range.
	 */
	if (addr < DRAM_BASE || addr >= dram_end || size > dram_end - addr ||
	    addr % PAGE_SIZE)
		return FWDL_STATUS_INVALID_ADDRESS;

	if (size > MAX_FW_SIZE || size < sizeof(struct fw_base_header))
		return FWDL_STATUS_INVALID_SIZE;

	fw_size_align = page_align(size, UP);

	ret = mmap_add_dynamic_region_alloc_va(addr, &fw_base, fw_size_align,
					       MT_MEMORY | MT_RW | MT_NS);
	if (ret)
		return FWDL_STATUS_FW_MMAP_FAILED;

#ifdef FWDL_DEBUG
	console_switch_state(CONSOLE_FLAG_BOOT);
#endif

	/*
	 * NOTE(review): every check below reads the image in place from the
	 * MT_NS mapping, and fw_load_real() later consumes that same memory.
	 * If the normal world can write this buffer concurrently, what was
	 * validated and verified may differ from what gets loaded. Consider
	 * copying the image into secure memory first. TODO confirm the threat
	 * model.
	 */
	status = fw_validate((void *)fw_base, size, &total_size);
	if (status)
		goto cleanup;

	FWDL_N("FWDL: Firmware data validation passed\n");

	status = find_fw_reg((void *)fw_base, &fwir);
	if (status)
		goto cleanup;

	ret = fw_signature_verify((void *)fw_base, size, total_size,
				  fwir->keyhash);
	if (ret) {
		status = FWDL_STATUS_FW_VALIDATION_FAILED;
		goto cleanup;
	}

	status = fw_load_real((void *)fw_base, flags, fwir);

	FWDL_N("FWDL: Firmware load result: %u\n", status);

cleanup:
	mmap_remove_dynamic_region(fw_base, fw_size_align);

#ifdef FWDL_DEBUG
	console_switch_state(CONSOLE_FLAG_RUNTIME);
#endif

	return status;
}
