/*
 *  xmm_fma.c
 *  xmmLibm
 *
 *  Created by Ian Ollmann on 8/8/05.
 *  Copyright 2005 Apple Computer Inc. All rights reserved.
 *
 */


#include "xmmLibm_prefix.h"

#include <math.h>

//For any rounding mode, we can calculate A + B exactly as a head-to-tail result:
//
//  Rhi = A + B         (A must have the larger magnitude)
//  Rlo = B - (Rhi - A)
//
//  Rhi is the sum, rounded in the current rounding mode
//  Rlo recovers the low-order bits that did not fit in Rhi

//returns carry bits that don't fit into A
static inline long double extended_accum( long double *A, long double B ) ALWAYS_INLINE;
static inline long double extended_accum( long double *A, long double B ) 
{
    long double r = *A + B;
    long double carry = B - ( r - *A );
    *A = r;
    return carry;
}
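
// A minimal sanity check of extended_accum, assuming round-to-nearest
// (an illustrative sketch only; the #if 0 guard keeps it out of the build).
#if 0
#include <assert.h>
static void extended_accum_example( void )
{
    long double hi = 1.0L;
    long double lo = extended_accum( &hi, 0x1p-70L );   // 2^-70 is well below half an ulp of 1.0L
    assert( hi == 1.0L );           // the rounded sum
    assert( lo == 0x1p-70L );       // the bits that did not fit in hi
}
#endif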


double fma( double a, double b, double c )
{
    static const xUInt64 mask26 = { 0xFFFFFFFFFC000000ULL, 0 };
    double ahi, bhi;

    //break a and b into high and low components
    //The high components keep the top 26 mantissa bits (27 significant bits,
    //counting the implicit leading bit); the low components get the
    //remaining 26 mantissa bits
    xDouble xa = DOUBLE_2_XDOUBLE(a);
    xDouble xb = DOUBLE_2_XDOUBLE(b);

    xa = _mm_and_pd( xa, (xDouble) mask26 );
    xb = _mm_and_pd( xb, (xDouble) mask26 );
    
    ahi = XDOUBLE_2_DOUBLE( xa );
    bhi = XDOUBLE_2_DOUBLE( xb );
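
    //Equivalently, in scalar form (an illustrative sketch, not used here):
    //
    //    union { double d; uint64_t u; } t = { a };
    //    t.u &= 0xFFFFFFFFFC000000ULL;  //keep sign, exponent, top 26 mantissa bits
    //    ahi = t.d;                     //then alo = a - ahi holds the low 26 bits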

    //double precision doesn't have enough precision for the next part,
    //so we abandon it and go to extended.
    //
    //The problem is that the intermediate multiplication product needs to be
    //infinitely precise. While we can fix the mantissa part of the infinite
    //precision problem, we can't deal with the case where the product is
    //outside the range of representable double precision values. Extended
    //precision solves that problem, since every double value, and every
    //product of two double values, is a normalized number in extended precision.
    long double Ahi = ahi;
    long double Bhi = bhi;
    long double Alo = (long double) a - Ahi;
    long double Blo = (long double) b - Bhi;
    long double C = c;
    
    //The product A * B is now exactly:
    //
    //  Ahi*Bhi + Alo*Bhi + Ahi*Blo + Alo*Blo
    //
    //each of these partial products has at most 54 bits of precision, so each
    //is exact in extended precision (64-bit significand)
    long double AhiBhi = Ahi * Bhi;        //at most 54 bits
    long double AloBhi = Alo * Bhi;        //at most 53 bits
    long double AhiBlo = Ahi * Blo;        //at most 53 bits
    long double AloBlo = Alo * Blo;        //at most 52 bits
    
    //accumulate the results into two head/tail buffers. This is infinitely precise.
    //no effort is taken to make sure that a0 and a1 are actually head to tail
    long double a0 = AloBlo;
    long double a1 = extended_accum( &a0, AhiBlo );
    a1 += extended_accum( &a0, AloBhi );
    a1 += extended_accum( &a0, AhiBhi );

    //Add C. If C has greater magnitude than a0, we need to swap them
    if( fabsl( C ) > fabsl( a0 ) )
    {
        long double temp = C;
        C = a0;
        a0 = temp;
    }

    //this addition may carry bits out of a0 into a1, but with roughly 128 bits
    //of combined precision we are covered.
    a1 += extended_accum( &a0, C );
    
    //push the carry held in a1 back into a0. This should give us the correctly rounded result
    a1 = extended_accum( &a0, a1 );

    return a0;
}
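
// A small illustration of why the fused operation matters (hypothetical test
// code, not part of the library build): with a = 2^27 + 1, the exact square
// a*a = 2^54 + 2^28 + 1 needs 55 bits, so it rounds to 2^54 + 2^28 in double
// and the naive expression loses the trailing 1 that fma recovers.
#if 0
#include <assert.h>
static void fma_example( void )
{
    double a = 0x1p27 + 1.0;            // 2^27 + 1, exactly representable
    double c = -(0x1p54 + 0x1p28);      // exactly representable (bits 54 and 28)
    assert( a * a + c == 0.0 );         // a*a rounds to 2^54 + 2^28 first
    assert( fma( a, a, c ) == 1.0 );    // the exact product survives
}
#endif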




float fmaf( float a, float b, float c )
{
    xDouble xa = FLOAT_2_XDOUBLE( a );
    xDouble xb = FLOAT_2_XDOUBLE( b );
    xDouble xc = FLOAT_2_XDOUBLE( c );

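    //the product of two floats is exact in double: two 24-bit significands
    //give a product of at most 48 bits, which fits in double's 53-bit significand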
    xa = _mm_mul_sd( xa, xb );                     //exact
    xa = _mm_add_sd( xa, xc );                     //inexact

    return XDOUBLE_2_FLOAT( xa );                //double rounding, alas
}


long double fmal( long double a, long double b, long double c )
{
	/*****************
	 Edge cases, from Ian's code.
	*****************/
	
	union{ long double ld; struct{ uint64_t mantissa; int16_t sexp; }parts; }ua = {a};
	union{ long double ld; struct{ uint64_t mantissa; int16_t sexp; }parts; }ub = {b};
	
	int16_t sign = (ua.parts.sexp ^ ub.parts.sexp) & 0x8000;
	int32_t aexp = (ua.parts.sexp & 0x7fff);
	int32_t bexp = (ub.parts.sexp & 0x7fff);
	int32_t exp  =  aexp + bexp - 2*16383;	// unbiased exponent estimate of a*b (used in the underflow check below)
	
	//deal with NaN
	if( a != a )    return a;
	if( b != b )    return b;
	if( c != c )    return c;
	
	// a = ∞ or b = ∞
	if ((aexp == 0x7fff) || (bexp == 0x7fff)) {			// We've already dealt with NaN, so this is only true
														// if one of a and b is an inf.
		if (( b == 0.0L ) || ( a == 0.0L)) return a*b;	// Return NaN if a = ±∞, b = 0 (or a = 0, b = ±∞); c is not NaN here.
		if ( __builtin_fabsl(c) == __builtin_infl() ) {
			if ( sign & 0x8000 ) {						// a*b = -∞
				if ( c > 0 )
					return c - __builtin_infl();		// Return NaN if a*b = -∞, c = +∞
				else
					return c;							// Return -∞ if a*b = c = -∞
			} else {									// a*b = +∞
				if ( c < 0 )
					return c + __builtin_infl();		// Return NaN if a*b = +∞, c = -∞
				else
					return c;							// Return +∞ if a*b = c = +∞
			}
		}
	}
	
	// c = ∞
	if ( __builtin_fabsl(c) == __builtin_infl() ) return c; // a and b are finite at this point, so a*b + c = c.

	
	if( exp < -16382 - 63 )	// the product underflows far below the smallest extended subnormal; it cannot affect c
		return c;
	
	/*****************
	 Computation of a*b + c. scanon, Jan 2007
	 
	 The whole game we're playing here only works in round-to-nearest.
	 *****************/
	
	long double ahi, alo, bhi, blo, phi, plo, xhi, xlo, yhi, ylo, tmp;
	
	//  split a,b into high and low parts.
	static const uint64_t split32_mask = 0xffffffff00000000ULL;
	ua.parts.mantissa = ua.parts.mantissa & split32_mask;
	ahi = ua.ld;								alo = a - ahi;
	ub.parts.mantissa = ub.parts.mantissa & split32_mask;
	bhi = ub.ld;								blo = b - bhi;
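	// Each half carries at most 32 significant bits, so every partial product
	// of halves has at most 64 bits and is exact in extended precision.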
	
	// compute the product of a and b as a head-tail sum phi + plo (plo is
	// computed below, after the overflow check).  This is exact.
	phi = a * b;
	
	// In case of overflow, stop here and return phi.  This will need to be changed
	// in order to have a fully C99 fmal.
	if (__builtin_fabsl(phi) == __builtin_infl()) {
		return phi; // We know that c != inf or nan, so phi is the correct result.
	}
	
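	// Recover the rounding error of phi = a*b (Dekker-style two-product):
	// summing the exact partial products against -phi, largest first, leaves
	// exactly the low-order bits that rounding dropped from phi.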
	plo = (((ahi * bhi - phi) + ahi*blo) + alo*bhi) + alo*blo;
	
	// compute plo + c = xhi + xlo where (xhi,xlo) is head-tail.
	xhi = plo + c;
	if (__builtin_fabsl(plo) > __builtin_fabsl(c)) {
		tmp = xhi - plo;
		xlo = c - tmp;
	} else {
		tmp = xhi - c;
		xlo = plo - tmp;
	}
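	// (This is the usual branching two-sum: subtracting the larger operand
	// back out of the rounded sum recovers the rounding error of xhi exactly.)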
	
	// Special case: xlo == 0, hence return phi + xhi:
	if (xlo == 0.0L) return phi + xhi;
	
	// At this point we know that the meaningful bits of phi and xhi are entirely to the
	// left (larger) side of the meaningful bits of xlo, and that our result is 
	// round(phi + xhi + xlo).
	yhi = phi + xhi;
	if (__builtin_fabsl(phi) > __builtin_fabsl(xhi)) {
		tmp = yhi - phi;
		ylo = xhi - tmp;
	} else {
		tmp = yhi - xhi;
		ylo = phi - tmp;
	}
	
	// Handle the special case that one of yhi or ylo is zero.
	// If yhi == 0, then ylo is also zero, so yhi + xlo = xlo is the appropriate result.
	// If ylo == 0, then yhi + xlo is the appropriate result.
	if ((yhi == 0.0L) || (ylo == 0.0L)) return yhi + xlo;
	
	// Now we have that (in terms of meaningful bits)
	// yhi is strictly bigger than ylo is strictly bigger than xlo.
	// Additionally, our desired result is round(yhi + ylo + xlo).
	
	// The only way for xlo to affect rounding (in round-to-nearest) is for ylo to
	// be exactly half an ulp of yhi.  Test for the value of the mantissa of ylo;
	// this is not the same condition, but getting this wrong can't affect the rounding.
	union{ long double ld; struct{ uint64_t mantissa; int16_t sexp; }parts; } uylo = {ylo};
	if (uylo.parts.mantissa == 0x8000000000000000ULL) {
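		// ylo is a power of two here, so 0.5*ylo and 1.5*ylo are both exact.
		// Adding copysignl(0.5L*ylo, xlo) pushes the tie off dead center in
		// the direction of xlo, so round-to-nearest breaks the tie the same
		// way the infinitely precise sum would.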
		return yhi + (ylo + copysignl(0.5L*ylo, xlo));
	}
	// In the other case, xlo has no effect on the final result, so just return yhi + ylo
	else {
		return yhi + ylo;
	}
	
	/*
	 Code that Ian wrote to do this in integer, never finished.
	 
	 //mantissa product
    //  hi(a.hi*b.hi)   lo(a.hi*b.hi)
    //                  hi(a.hi*b.lo)   lo(a.hi*b.lo)
    //                  hi(a.lo*b.hi)   lo(a.lo*b.hi)
    // +                                hi(a.lo*b.lo)   lo(a.lo*b.lo)
    // --------------------------------------------------------------

    uint32_t    ahi = ua.parts.mantissa >> 32;
    uint32_t    bhi = ub.parts.mantissa >> 32;
    uint32_t    alo = ua.parts.mantissa & 0xFFFFFFFFULL;
    uint32_t    blo = ub.parts.mantissa & 0xFFFFFFFFULL;

    uint64_t    templl, temphl, templh, temphh;
    xUInt64     l, h, a;
    
    templl = (uint64_t) alo * (uint64_t) blo;
    temphl = (uint64_t) ahi * (uint64_t) blo;
    templh = (uint64_t) alo * (uint64_t) bhi;
    temphh = (uint64_t) ahi * (uint64_t) bhi;

    l = _mm_cvtsi32_si128( (int32_t) templl );          templl >>= 32;

    temphl +=  templl;
    temphl +=  templh & 0xFFFFFFFFULL;
    h = _mm_cvtsi32_si128( (int32_t) temphl);           temphl >>= 32;
    a = _mm_unpacklo_epi32( l, h );
    

    temphh += templh >> 32;
    temphh += temphl;

    a = _mm_cvtsi32_si128( (int32_t) temphh );          temphh >>= 32;
    h = _mm_cvtsi32_si128( (int32_t) temphh);           
    a = _mm_unpacklo_epi32( a, h );
    l = _mm_unpacklo_epi64( l, a );
    
    union
    {
        xUInt64     v;
        uint64_t    u[2];
    }z = { l };

    long double lo = (long double) temphh.
	 */
}