// SPDX-License-Identifier: GPL-2.0+
/*
 * Copyright (C) 2022 MediaTek Incorporation. All Rights Reserved.
 *
 * Author: guan-gm.lin <guan-gm.lin@mediatek.com>
 */

#ifndef USE_HOSTCC
#include <malloc.h>
#include <linux/arm-smccc.h>
#include <linux/iopoll.h>
#include <clk.h>
#include <misc.h>
#include <dm.h>
#include <log.h>
#include <misc.h>
#include <tee.h>
#endif /* ifndef USE_HOSTCC */
#include <image.h>
#include <uboot_aes.h>

#ifndef USE_HOSTCC
/* UUID that session_init() converts to octets to open the TA session. */
#define TA_FIRMWARE_ENCRYPTION_UUID \
	{ 0x503810ea, 0x5f92, 0x49d3, \
		{ 0xa5, 0xf3, 0x87, 0xe9, 0xed, 0x02, 0x76, 0xa9 } }

/* Function IDs stored in tee_invoke_arg.func by the set_*() helpers. */
#define SET_ALGO		1
#define SET_IV			2
#define SET_DATA		3
#define SET_SALT		4

/* Values passed as the two SET_ALGO parameters by set_algo(). */
#define TA_AES_ALGO_CBC		1
#define TA_AES_SIZE_256BIT	(256 / 8)	/* 256 bits expressed in bytes */

/**
 * get_salt() - look up the "salt" property of the FIT images node
 * @fit:	FIT blob to search
 * @salt:	on success, set to point at the property data inside @fit
 * @salt_len:	on success, set to the property length in bytes
 *
 * Nothing is allocated: the returned pointer refers to data inside @fit.
 * @salt and @salt_len are left untouched on failure.
 *
 * Return: 0 on success, -1 if the images node or a non-empty "salt"
 * property cannot be found
 */
static int get_salt(const void *fit, unsigned char **salt, u32 *salt_len)
{
	const void *prop;
	int images_noffset;
	int len = 0;

	images_noffset = fdt_path_offset(fit, FIT_IMAGES_PATH);
	if (images_noffset < 0) {
		printf("Can't get found '/images'\n");
		return -1;
	}

	/*
	 * fdt_getprop() reports the length through an int (negative on
	 * error), so use a local int instead of aliasing the caller's u32.
	 */
	prop = fdt_getprop(fit, images_noffset, "salt", &len);
	if (!prop || len <= 0) {
		printf("Can't get salt\n");
		return -1;
	}

	/* The TEE helpers take a non-const buffer, hence the cast. */
	*salt = (unsigned char *)prop;
	*salt_len = (u32)len;

	return 0;
}

/**
 * session_init() - open a session to the firmware encryption TA
 * @tee:		TEE device to talk to
 * @tee_session:	on success, set to the identifier of the new session
 *
 * Both the return value of tee_open_session() and the result reported in
 * the open-session argument are checked; the session identifier is only
 * stored when both indicate success.
 *
 * Return: 0 on success, -1 on failure
 */
static int session_init(struct udevice *tee, u32 *tee_session)
{
	struct tee_open_session_arg arg;
	struct tee_optee_ta_uuid uuid = TA_FIRMWARE_ENCRYPTION_UUID;
	int res;

	memset(&arg, 0, sizeof(arg));
	tee_optee_ta_uuid_to_octets(arg.uuid, &uuid);
	arg.clnt_login = TEE_LOGIN_PUBLIC;

	res = tee_open_session(tee, &arg, 0, NULL);
	if (res) {
		printf("Failed: open ta, error code %x\n", res);
		return -1;
	}

	/* The call itself went through, but the TEE may still refuse. */
	if (arg.ret) {
		printf("Failed: open ta, return %x\n", arg.ret);
		return -1;
	}

	*tee_session = arg.session;

	return 0;
}

/**
 * session_deinit() - close a session previously opened on the TEE device
 * @tee:		TEE device the session belongs to
 * @tee_session:	identifier of the session to close
 *
 * Return: the value returned by tee_close_session() (0 on success)
 */
static int session_deinit(struct udevice *tee, u32 tee_session)
{
	int rc = tee_close_session(tee, tee_session);

	if (rc)
		printf("tee_close_session error, res: %x\n", rc);

	return rc;
}

/**
 * set_iv() - hand the initialisation vector to the TA
 * @tee:		TEE device
 * @tee_session:	open session to the TA
 * @iv:			IV buffer, registered as shared memory for the call
 * @iv_len:		size of @iv in bytes
 *
 * Invokes SET_IV with @iv as a single input memory reference. The shared
 * memory registration is released before returning.
 *
 * Return: 0 on success, the error from tee_shm_register() or
 * tee_invoke_func(), or the non-zero result reported by the TA
 */
static int set_iv(struct udevice *tee, u32 tee_session,
		  unsigned char *iv, u32 iv_len)
{
	int res;
	struct tee_invoke_arg arg;
	struct tee_param param[1];
	struct tee_shm *iv_shm = NULL;

	memset(param, 0, sizeof(param));
	memset(&arg, 0, sizeof(arg));

	res = tee_shm_register(tee, iv, iv_len, 0, &iv_shm);
	if (res) {
		printf("iv register share memory failed\n");
		goto out;
	}

	arg.func = SET_IV;
	arg.session = tee_session;

	param[0].attr = TEE_PARAM_ATTR_TYPE_MEMREF_INPUT;
	param[0].u.memref.shm = iv_shm;
	param[0].u.memref.size = iv_len;

	res = tee_invoke_func(tee, &arg, 1, param);
	if (res) {
		/* Transport-level failure: the TA result is not meaningful. */
		printf("Failed: set iv tee_invoke_func error %x\n", res);
		goto out;
	}

	if (arg.ret) {
		/* The call was delivered but the TA rejected it. */
		res = arg.ret;
		printf("Failed: set iv error, return %x\n", arg.ret);
	}

out:
	if (iv_shm != NULL)
		tee_shm_free(iv_shm);

	return res;
}

/**
 * set_salt() - hand the salt to the TA
 * @tee:		TEE device
 * @tee_session:	open session to the TA
 * @salt:		salt buffer, registered as shared memory for the call
 * @salt_len:		size of @salt in bytes
 *
 * Invokes SET_SALT with @salt as a single input memory reference. The
 * shared memory registration is released before returning.
 *
 * Return: 0 on success, the error from tee_shm_register() or
 * tee_invoke_func(), or the non-zero result reported by the TA
 */
static int set_salt(struct udevice *tee, u32 tee_session,
		    unsigned char *salt, u32 salt_len)
{
	struct tee_shm *shm = NULL;
	struct tee_invoke_arg invoke;
	struct tee_param params[1];
	int rc;

	memset(&invoke, 0, sizeof(invoke));
	memset(params, 0, sizeof(params));

	invoke.session = tee_session;
	invoke.func = SET_SALT;

	rc = tee_shm_register(tee, salt, salt_len, 0, &shm);
	if (rc) {
		printf("salt register share memory failed\n");
		goto release;
	}

	params[0].attr = TEE_PARAM_ATTR_TYPE_MEMREF_INPUT;
	params[0].u.memref.size = salt_len;
	params[0].u.memref.shm = shm;

	rc = tee_invoke_func(tee, &invoke, 1, params);
	if (rc) {
		/* The invocation itself failed */
		printf("Failed: set salt tee_invoke_func error %x\n", rc);
	} else if (invoke.ret) {
		/* Invocation succeeded but the TA returned an error */
		rc = invoke.ret;
		printf("Failed: set salt error, return %x\n", invoke.ret);
	}

release:
	if (shm)
		tee_shm_free(shm);

	return rc;
}

/**
 * set_buffer() - pass the cipher text to the TA and collect the plain text
 * @tee:		TEE device
 * @tee_session:	open session to the TA
 * @cipher:		input buffer holding the encrypted data
 * @cipher_len:		size of @cipher in bytes
 * @plain:		output buffer
 * @plain_len:		size of @plain in bytes
 *
 * Invokes SET_DATA with @cipher as input memory reference and @plain as
 * output memory reference. Each buffer is registered and advertised with
 * its own length so the TA is never told that @plain is larger than it
 * really is. Both registrations are released before returning.
 *
 * Return: 0 on success, the error from tee_shm_register() or
 * tee_invoke_func(), or the non-zero result reported by the TA
 */
static int set_buffer(struct udevice *tee, u32 tee_session,
		      unsigned char *cipher, u32 cipher_len,
		      unsigned char *plain, u32 plain_len)
{
	int res;
	struct tee_invoke_arg arg;
	struct tee_param param[2];
	struct tee_shm *cipher_shm = NULL;
	struct tee_shm *plain_shm = NULL;

	memset(param, 0, sizeof(param));
	memset(&arg, 0, sizeof(arg));

	res = tee_shm_register(tee, cipher, cipher_len, 0, &cipher_shm);
	if (res) {
		printf("cipher data share memory failed\n");
		goto out;
	}

	/* Register the output with its real size, not the input size. */
	res = tee_shm_register(tee, plain, plain_len, 0, &plain_shm);
	if (res) {
		printf("plain data share memory failed\n");
		goto out;
	}

	arg.func = SET_DATA;
	arg.session = tee_session;

	param[0].attr = TEE_PARAM_ATTR_TYPE_MEMREF_INPUT;
	param[0].u.memref.shm = cipher_shm;
	param[0].u.memref.size = cipher_len;
	param[1].attr = TEE_PARAM_ATTR_TYPE_MEMREF_OUTPUT;
	param[1].u.memref.shm = plain_shm;
	param[1].u.memref.size = plain_len;

	res = tee_invoke_func(tee, &arg, 2, param);
	if (res) {
		/* Transport-level failure: the TA result is not meaningful. */
		printf("Failed: set buffer tee_invoke_func error %x\n", res);
		goto out;
	}

	if (arg.ret) {
		/* The call was delivered but the TA rejected it. */
		res = arg.ret;
		printf("Failed: set buffer error, return %x\n", arg.ret);
	}

out:
	if (cipher_shm != NULL)
		tee_shm_free(cipher_shm);

	if (plain_shm != NULL)
		tee_shm_free(plain_shm);

	return res;
}

/**
 * set_algo() - select the cipher configuration in the TA
 * @tee:		TEE device
 * @tee_session:	open session to the TA
 *
 * Invokes SET_ALGO with two value parameters: TA_AES_ALGO_CBC and
 * TA_AES_SIZE_256BIT.
 *
 * Return: 0 on success, the error from tee_invoke_func(), or the non-zero
 * result reported by the TA
 */
static int set_algo(struct udevice *tee, u32 tee_session)
{
	struct tee_invoke_arg invoke;
	struct tee_param params[2];
	int rc;

	memset(&invoke, 0, sizeof(invoke));
	memset(params, 0, sizeof(params));

	invoke.session = tee_session;
	invoke.func = SET_ALGO;

	params[0].attr = TEE_PARAM_ATTR_TYPE_VALUE_INPUT;
	params[0].u.value.a = TA_AES_ALGO_CBC;

	params[1].attr = TEE_PARAM_ATTR_TYPE_VALUE_INPUT;
	params[1].u.value.a = TA_AES_SIZE_256BIT;

	rc = tee_invoke_func(tee, &invoke, 2, params);
	if (rc) {
		/* The invocation itself failed */
		printf("Failed: set algo tee_invoke_func error %x\n", rc);
		return rc;
	}

	if (invoke.ret) {
		/* Invocation succeeded but the TA returned an error */
		rc = invoke.ret;
		printf("Failed: set algo error, return %x\n", invoke.ret);
	}

	return rc;
}

/**
 * image_decrypt_via_optee() - run one complete decryption through the TA
 * @tee:	TEE device
 * @salt:	salt buffer
 * @salt_len:	size of @salt in bytes
 * @iv:		initialisation vector
 * @iv_len:	size of @iv in bytes
 * @cipher:	encrypted input data
 * @cipher_len:	size of @cipher in bytes
 * @plain:	output buffer for the decrypted data
 * @plain_len:	size of @plain in bytes
 *
 * Opens a session, then performs set_algo(), set_salt(), set_iv() and
 * set_buffer() in that order, stopping at the first failure. The session
 * is closed on every path on which it was successfully opened.
 *
 * Return: 0 on success, non-zero error from the first failing step
 */
static int image_decrypt_via_optee(struct udevice *tee,
				   unsigned char *salt, u32 salt_len,
				   unsigned char *iv, u32 iv_len,
				   unsigned char *cipher, u32 cipher_len,
				   unsigned char *plain, u32 plain_len)
{
	u32 tee_session = 0;
	int res;

	res = session_init(tee, &tee_session);
	if (res) {
		/* No session was opened, so there is nothing to close. */
		printf("Failed: init firmware ta\n");
		return res;
	}

	res = set_algo(tee, tee_session);
	if (res) {
		printf("Failed: set algo\n");
		goto out;
	}

	res = set_salt(tee, tee_session, salt, salt_len);
	if (res) {
		printf("Failed: set salt\n");
		goto out;
	}

	res = set_iv(tee, tee_session, iv, iv_len);
	if (res) {
		printf("Failed: set iv\n");
		goto out;
	}

	res = set_buffer(tee, tee_session, cipher, cipher_len,
			 plain, plain_len);
	if (res)
		printf("Failed: set buffer\n");

out:
	/*
	 * Closing is best effort: a close failure is logged by
	 * session_deinit() but does not override the decryption result.
	 */
	session_deinit(tee, tee_session);

	return res;
}
#endif /* ifndef USE_HOSTCC */

int mtk_image_aes_decrypt(struct image_cipher_info *info,
			  const void *cipher, size_t cipher_len,
			  void **data, size_t *size)
{
#ifndef USE_HOSTCC
	struct udevice *tee;
	u32 salt_len, iv_len;
	unsigned char *salt, *iv;

	iv = (unsigned char*) info->iv;
	iv_len = info->cipher->iv_len;
	if (iv == NULL || iv_len == 0) {
		printf("iv is NULL or iv_len is 0\n");
		return -EINVAL;
	}

	if (get_salt(info->fit, &salt, &salt_len))
		return -EINVAL;

	*data = malloc(cipher_len);
	if (!*data) {
		printf("Can't allocate memory to decrypt\n");
		return -ENOMEM;
	}
	memset(*data, 0, cipher_len);
	*size = info->size_unciphered;


	tee = tee_find_device(NULL, NULL, NULL, NULL);
	if (!tee) {
		printf("Can't find tee device\n");
		return -ENODEV;
	}

	if(image_decrypt_via_optee(tee, salt, salt_len, iv, iv_len,
			(unsigned char *)cipher, cipher_len, *data, cipher_len)) {
		printf("Failed: image decryption via OP-TEE\n");
		return -EINVAL;
	}

#endif /* ifndef USE_HOSTCC */
	return 0;
}
