/*
 * mp_mont.c
 *
 * This code is in the public domain. I would appreciate bug reports and
 * enhancements.
 *
 * Duncan S Wong <swong@ieee.org>
 *
 * Dec 14, 2000 - Initial Version
 */
#include <PalmOS.h>
#include "mp.h"
#include "mp_priv.h"

// Create a new Montgomery context.
//
// The MP_MONT_CTX record itself comes from MemPtrNew; its three INT members
// (R_m, R2_m and m) come from MP_new.  n and mi start out as zero.
//
// Returns the new context, or NULL when any of the allocations fails.  If
// the record was obtained but one of the INTs was not, the partially built
// context is handed to MP_MONT_CTX_free before NULL is returned.
MP_MONT_CTX *MP_MONT_CTX_new(void)
{
MP_MONT_CTX *mont;

  mont = (MP_MONT_CTX *) MemPtrNew(sizeof(MP_MONT_CTX));
  if (mont == NULL)
    return(NULL);

  mont->n  = 0;
  mont->mi = (DIGIT)0;

  mont->R_m  = MP_new();
  mont->R2_m = MP_new();
  mont->m    = MP_new();

  // All three members present: the context is ready to be filled in.
  if ((mont->R_m != NULL) && (mont->R2_m != NULL) && (mont->m != NULL))
    return(mont);

  // At least one MP_new failed; release whatever was obtained.
  MP_MONT_CTX_free(mont);
  return(NULL);
}


// Clear and free the memory chunk for Montgomery multiplication.
//
// Resets n and mi, releases each non-NULL INT member (R_m, R2_m, m) with
// MP_clear_free, then releases the record itself with MemPtrFree.
//
// Note : - mont may be NULL, in which case nothing is done.
//        - mont may be partially built (some INT members NULL), as happens
//          when MP_MONT_CTX_new fails part-way through.
void MP_MONT_CTX_free(MP_MONT_CTX *mont)
{
  if (mont == NULL) return;

  mont->n  = 0;
  mont->mi = (DIGIT)0;
  if (mont->R_m != NULL) MP_clear_free(mont->R_m);
  if (mont->R2_m != NULL) MP_clear_free(mont->R2_m);
  if (mont->m != NULL) MP_clear_free(mont->m);
  MemPtrFree(mont);
}


// Set the parameters of the data structure for Montgomery multiplication.
//
// With b = 2^DIGIT_BITS and n = number of digits in modulus, fills in
//   mont->n    = n
//   mont->m    = copy of modulus
//   mont->R_m  = R mod m       where R = b^n
//   mont->R2_m = R^2 mod m
//   mont->mi   = -m^{-1} mod b
//
// Returns 1 on success.  Returns 0 if an argument is NULL, if modulus is
// negative or not odd, or if any of the underlying MP operations fails; in
// the last case mont may have been partially updated.
//
// Note : - modulus must be odd and positive.
//        - mont must already exist (see MP_MONT_CTX_new).  It is passed by
//          value, so a context created in here could never reach the
//          caller; a NULL mont is therefore rejected.
//        - three temporaries are taken from ctx and handed back on return.
Int16 MP_MONT_CTX_set(MP_MONT_CTX *mont, INT *modulus, MP_CTX *ctx)
{
INT *b, *t1, *t2, *inv;
Int16 ret = 0;

  // Argument checks come before any ctx temporaries are claimed, so these
  // early returns leave ctx->tos untouched.
  if ((mont == NULL) || (modulus == NULL) || (ctx == NULL)) return(ret);
  if (modulus->neg) return(ret);
  if (!MP_is_odd(modulus)) return(ret);

  t1 = ctx->t[ctx->tos++];
  t2 = ctx->t[ctx->tos++];
  b  = ctx->t[ctx->tos++];

  // b = 2^DIGIT_BITS, written out as the two-digit number (1, 0)
  if(b->max <= 2) {
    if(!MP_alloc(b, 2*DIGIT_BITS)) goto err;
  }
  b->top  = 2;
  b->neg  = 0;
  b->d[0] = (DIGIT)0;
  b->d[1] = (DIGIT)1;

  mont->n = modulus->top;
  if(MP_copy(mont->m, modulus) == NULL) goto err;

  // t1 = R = b^n : n zero digits with a single 1 above them.  t1 is a
  // reused temporary, so its sign is reset along with its digits.
  if(!MP_alloc(t1, (mont->n + 1)*DIGIT_BITS)) goto err;
  t1->top = mont->n + 1;
  t1->neg = 0;
  memset(t1->d, 0, (t1->top - 1)*sizeof(DIGIT));
  t1->d[t1->top-1] = (DIGIT)1;

  // R_m = R mod m, R2_m = R^2 mod m
  if(!MP_div(NULL, mont->R_m, t1, mont->m, ctx)) goto err;
  if(!MP_sqr(t2, t1)) goto err;
  if(!MP_div(NULL, mont->R2_m, t2, mont->m, ctx)) goto err;

  // compute mi = -m^{-1} mod b where b = 2^DIGIT_BITS
  //   t1 = m mod b, then t1 = b - t1 (= -m mod b), then invert t1 mod b
  if(!MP_div(NULL, t1, mont->m, b, ctx)) goto err;
  if(!MP_sub(t1, b, t1)) goto err;

  // The inverse is kept in its own pointer: t2 belongs to ctx and must not
  // be freed or replaced here.
  // NOTE(review): assumes MP_mod_inverse returns a newly allocated INT that
  // the caller owns and must release -- confirm against mp.h.
  inv = MP_mod_inverse(t1, b, ctx);
  if(inv == NULL) goto err;
  mont->mi = inv->d[0];
  MP_clear_free(inv);

  ret = 1;
err:
  ctx->tos -= 3;
  return(ret);
}


// Set r to abR^{-1} (mod m) where
// information of R and m are in (MP_MONT_CTX *) mont.
//
// Montgomery Multiplication
// (Algorithm 14.36 on p.602 of HAC)
//
// The accumulator A starts at 0.  For every digit a_i of a, one round does
//   u = (A_0 + a_i * b_0) * mi  mod 2^DIGIT_BITS
//   A = (A + a_i * b + u * m) / 2^DIGIT_BITS
// Digits of a above a->top are zero, so the remaining rounds (up to
// m->top in total) only add u * m.  A last conditional subtraction of m
// brings the result below m.
//
// Returns 1 on success and 0 if a temporary cannot be grown, an addition
// or subtraction fails, or the result cannot be copied into r.
//
// Note : - We assume that a, b < mont->m.
//        - r can be a or b.
//        - As R = b^n, the modulus (mont->m) need to be odd.
//        - Two temporaries are taken from ctx and handed back on return.
Int16 MP_mont_mul(INT *r, INT *a, INT *b, MP_MONT_CTX *mont, MP_CTX *ctx)
{
INT *A, *t1;
register DIGIT u, carry;
Int16 i, n, ret=0;

  A    = ctx->t[ctx->tos++];
  t1   = ctx->t[ctx->tos++];

  // t1 holds a (digit * b) or (digit * m) product.  Such a product can be
  // one digit longer than its multi-digit factor, because the carry is
  // stored at d[top]; reserve room for that extra digit.
  n = mont->m->top;
  if(b->top > n) n = b->top;
  if(!MP_alloc(t1, (n + 1)*DIGIT_BITS)) goto err;
  // t1 is a reused temporary whose digits are written directly below, so
  // make sure a sign left over from an earlier use is not added in.
  t1->neg = 0;

  MP_zero(A);
  for(i=0; i < a->top; i++) {
    // compute u = (a0 + xi y0) mi (mod b)
    // only need the first DIGIT, so I don't care the overflow
    u = ((A->d[0] + a->d[i] * b->d[0]) * mont->mi) & DIGIT_MASK;
    // compute A = (A + xi y + u m) / b
    t1->top = b->top;
    carry = MP_mul_digit(t1->d, b->d, b->top, a->d[i]);
    if(carry) {t1->d[t1->top] = carry; t1->top++; }
    if(!MP_add(A, A, t1)) goto err;
    t1->top = mont->m->top;
    carry = MP_mul_digit(t1->d, mont->m->d, mont->m->top, u);
    if(carry) {t1->d[t1->top] = carry; t1->top++; }
    if(!MP_add(A, A, t1)) goto err;
    // NOTE(review): the result of MP_rshift is not inspected, as in the
    // original code; its return convention is not visible from this file.
    MP_rshift(A, A, DIGIT_BITS);
  }
  for(i=a->top; i<mont->m->top; i++) {
    // compute u = a0 mi (mod b)
    u = (A->d[0] * mont->mi) & DIGIT_MASK;
    // compute A = (A + u m) / b
    t1->top = mont->m->top;
    carry = MP_mul_digit(t1->d, mont->m->d, mont->m->top, u);
    if(carry) {t1->d[t1->top] = carry; t1->top++; }
    if(!MP_add(A, A, t1)) goto err;
    MP_rshift(A, A, DIGIT_BITS);
  }

  // HAC 14.36 step 3: if A >= m then A = A - m.  The equality case matters:
  // without it A == m would be returned instead of 0.
  if(MP_cmp(A, mont->m) >= 0) {
    if(!MP_sub(A, A, mont->m)) goto err;
  }

  if(MP_copy(r, A) == NULL) goto err;
  ret = 1;
err:
  ctx->tos -= 2;
  return(ret);
}


// Set r to a^e mod m using Montgomery exponentiation.
// (Algorithm 14.94 on p.620 of HAC)
//
// Steps:
//   1. reduce a modulo m if a >= m
//   2. t1 = a * R mod m         (MP_mont_mul of a with R^2 mod m)
//   3. A  = R mod m
//   4. for each bit of e, most significant first:
//        A = A * A; if the bit is set, A = A * t1   (both via MP_mont_mul)
//   5. A = MP_mont_mul(A, 1)    which strips the remaining factor R
//
// Returns 1 on success and 0 if any underlying operation fails.
//
// Note : - r can be a.
//        - m is in (MP_MONT_CTX *)mont.
//        - m must be odd.
//        - a == m is handled like any other a >= m: it is reduced to 0 and
//          follows the general path, so e == 0 gives the same result for
//          every base.
//        - Two temporaries are taken from ctx and handed back on return.
Int16 MP_mont_exp(INT *r, INT *a, INT *e, MP_MONT_CTX *mont, MP_CTX *ctx)
{
Int16 i, bits, ret=0;
INT *A, *t1;

  t1 = ctx->t[ctx->tos++];
  A  = ctx->t[ctx->tos++];

  // A = a mod m
  if(MP_cmp(a, mont->m) >= 0) {
    if(!MP_div(NULL, A, a, mont->m, ctx)) goto err;
  }
  else {
    if(MP_copy(A, a) == NULL) goto err;
  }

  // t1 = a R mod m, A = R mod m
  if(!MP_mont_mul(t1, A, mont->R2_m, mont, ctx)) goto err;
  if(MP_copy(A, mont->R_m) == NULL) goto err;

  // Left-to-right square and multiply.  When MP_num_bits(e) is 0, bits is
  // -1 and the loop body never runs.
  bits = MP_num_bits(e) - 1;
  for (i=bits; i>=0; i--) {
    if(!MP_mont_mul(A, A, A, mont, ctx)) goto err;
    if(MP_is_bit_set(e, i)) {
      if(!MP_mont_mul(A, A, t1, mont, ctx)) goto err;
    }
  }

  // Multiply by 1 to remove the factor R carried by A.
  MP_one(t1);
  if(!MP_mont_mul(A, A, t1, mont, ctx)) goto err;

  if(MP_copy(r, A) == NULL) goto err;
  ret=1;
err:
  ctx->tos -= 2;
  return(ret);
}
