// Copyright (c) 2024 Celestino Amoroso (celestino.amoroso@gmail.com).
// All rights reserved.

// operand-fraction.go
package expr

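// Background (in Italian) on fractions and their decimal expansions, including
// repeating decimals:
// https://www.youmath.it/lezioni/algebra-elementare/lezioni-di-algebra-e-aritmetica-per-scuole-medie/553-dalle-frazioni-a-numeri-decimali.html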

import (
	"errors"
	"fmt"
	"math"
	"strconv"
	"strings"
)

// fraction is the rational number num/den.
// Values built through newFraction or simplifyFraction are in lowest terms
// with den > 0; a composite literal such as &fraction{n, d} is not
// normalised automatically.
type fraction struct {
	num int64 // numerator; carries the sign of normalised values
	den int64 // denominator
}

// newFraction returns num/den reduced to lowest terms with a positive
// denominator; a zero numerator yields 0/1. Sign normalisation and reduction
// are delegated to simplifyIntegers, which panics when den is zero and num
// is not.
func newFraction(num, den int64) *fraction {
	n, d := simplifyIntegers(num, den)
	return &fraction{num: n, den: d}
}

// float64ToFraction converts f into a fraction. It rounds f to 12 decimal
// places and takes the generating fraction of that decimal string, so binary
// floating-point noise beyond the 12th decimal is discarded (for example
// 0.1+0.2 becomes 3/10).
//
// NaN and ±Inf are rejected with an error. Very large magnitudes may not fit
// the int64 numerator/denominator used by makeGeneratingFraction.
func float64ToFraction(f float64) (fract *fraction, err error) {
	if math.IsNaN(f) || math.IsInf(f, 0) {
		return nil, fmt.Errorf("cannot convert %v to a fraction", f)
	}
	// Format the whole value in one go. Rounding the decimals up must carry
	// into the integer part (2.9999999999999 -> "3.000000000000"). Formatting
	// the fractional part on its own would turn it into "1.000000000000" and
	// drop that carry.
	return makeGeneratingFraction(strconv.FormatFloat(f, 'f', 12, 64))
}

// Based on https://cs.opensource.google/go/go/+/refs/tags/go1.22.3:src/math/big/rat.go;l=39
/*
func _float64ToFraction(f float64) (num, den int64, err error) {
	const expMask = 1<<11 - 1
	bits := math.Float64bits(f)
	mantissa := bits & (1<<52 - 1)
	exp := int((bits >> 52) & expMask)
	switch exp {
	case expMask: // non-finite
		err = errors.New("infinite")
		return
	case 0: // denormal
		exp -= 1022
	default: // normal
		mantissa |= 1 << 52
		exp -= 1023
	}

	shift := 52 - exp

	// Optimization (?): partially pre-normalise.
	for mantissa&1 == 0 && shift > 0 {
		mantissa >>= 1
		shift--
	}

	if f < 0 {
		num = -int64(mantissa)
	} else {
		num = int64(mantissa)
	}
	den = int64(1)

	if shift > 0 {
		den = den << shift
	} else {
		num = num << (-shift)
	}
	return
}
*/

// makeGeneratingFraction converts a decimal literal into its generating
// fraction, i.e. the fraction whose decimal expansion it is.
//
// Accepted syntax, with an optional single leading '+' or '-':
//
//	"12"       integer
//	"1.25"     finite decimal (trailing zeros are ignored)
//	"0.(3)"    repeating decimal: the digits in parentheses repeat forever
//	"1.2(34)"  non-repeating digits followed by a repeating group
//
// A trailing "()" (an empty repeating group) is dropped before parsing.
// Malformed input yields an error: a non-digit character, a missing ')', an
// empty repeating group, or a doubled sign. An error is also returned when
// the numerator or denominator would not fit in an int64.
func makeGeneratingFraction(s string) (f *fraction, err error) {
	if len(s) == 0 {
		return nil, errors.New("bad syntax")
	}
	sign := int64(1)
	switch s[0] {
	case '-':
		sign = -1
		s = s[1:]
	case '+':
		s = s[1:]
	}
	s = strings.TrimSuffix(s, "()")

	intPart, decPart, hasDot := strings.Cut(s, ".")
	// The sign was consumed above. ParseInt would accept a second one, and a
	// negative integer part would corrupt the digit accumulation below
	// ("+-5.5" used to evaluate to -4.5).
	if strings.HasPrefix(intPart, "-") || strings.HasPrefix(intPart, "+") {
		return nil, errors.New("bad syntax")
	}
	var num int64
	if num, err = strconv.ParseInt(intPart, 10, 64); err != nil {
		return nil, err
	}
	if !hasDot {
		return newFraction(sign*num, 1), nil
	}

	fixed, repeating, hasRepeating := strings.Cut(decPart, "(")
	if !hasRepeating {
		// Finite decimal I.D: value = ID / 10^len(D), ignoring trailing zeros.
		den := int64(1)
		for _, c := range strings.TrimRight(decPart, "0") {
			if num, err = appendFractDigit(num, c); err != nil {
				return nil, err
			}
			if den, err = appendFractDigit(den, '0'); err != nil {
				return nil, err
			}
		}
		return newFraction(sign*num, den), nil
	}

	// Repeating decimal I.F(P): value = (IFP - IF) / (9...9 0...0), with one
	// 9 per digit of P and one 0 per digit of F.
	for _, c := range fixed {
		if num, err = appendFractDigit(num, c); err != nil {
			return nil, err
		}
	}
	sub := num // IF: integer part followed by the non-repeating digits
	period, closed := strings.CutSuffix(repeating, ")")
	if !closed || len(period) == 0 {
		// Unterminated group, or an empty group that would give a zero denominator.
		return nil, errors.New("bad syntax")
	}
	var den int64
	for _, c := range period {
		if num, err = appendFractDigit(num, c); err != nil {
			return nil, err
		}
		if den, err = appendFractDigit(den, '9'); err != nil {
			return nil, err
		}
	}
	for range fixed {
		if den, err = appendFractDigit(den, '0'); err != nil {
			return nil, err
		}
	}
	return newFraction(sign*(num-sub), den), nil
}

// appendFractDigit returns acc*10 plus the value of the decimal digit c
// (acc must be non-negative). It fails if c is not an ASCII digit or if the
// result would overflow int64.
func appendFractDigit(acc int64, c rune) (int64, error) {
	if c < '0' || c > '9' {
		return 0, errExpectedGot("fract", "digit", c)
	}
	d := int64(c - '0')
	if acc > (math.MaxInt64-d)/10 {
		return 0, errors.New("fraction out of int64 range")
	}
	return acc*10 + d, nil
}

// toFloat returns the value of the fraction as a float64 (num / den).
func (f *fraction) toFloat() float64 {
	num, den := float64(f.num), float64(f.den)
	return num / den
}

// String implements fmt.Stringer. It renders the fraction with no formatting
// options, i.e. the compact single-line form "num|den".
func (f *fraction) String() string {
	var noOpts FmtOpt
	return f.ToString(noOpts)
}

// ToString renders the fraction according to the formatting options in opt.
//
// Without MultiLine the result is the compact form "num|den". With MultiLine
// the numerator and denominator are each centred in a column as wide as the
// longer of the two:
//   - with TTY, the numerator line is wrapped in the ANSI underline escape
//     ("\x1b[4m" ... "\x1b[0m"), which acts as the fraction bar; the
//     numerator keeps its own minus sign;
//   - otherwise a separate line of '-' characters forms the bar. A negative
//     sign is placed in front of that bar, and the other two lines are
//     indented by two spaces to stay aligned.
func (f *fraction) ToString(opt FmtOpt) string {
	var sb strings.Builder
	if opt&MultiLine == 0 {
		fmt.Fprintf(&sb, "%d|%d", f.num, f.den)
		return sb.String()
	}

	tty := opt&TTY != 0
	signOnBar := f.num < 0 && !tty
	numStr := strconv.FormatInt(f.num, 10)
	if signOnBar {
		numStr = strconv.FormatInt(-f.num, 10)
	}
	denStr := strconv.FormatInt(f.den, 10)
	width := max(len(numStr), len(denStr))

	// center right-aligns text at the middle of the column, then pads it on
	// the right (negative width) to the full column width.
	center := func(text string) string {
		return fmt.Sprintf("%*s", -width, fmt.Sprintf("%*s", (width+len(text))/2, text))
	}

	indent := ""
	if signOnBar {
		indent = "  "
	}
	if tty {
		sb.WriteString("\x1b[4m" + center(numStr) + "\x1b[0m\n")
	} else {
		sb.WriteString(indent + center(numStr) + "\n")
		if signOnBar {
			sb.WriteString("- ")
		}
		sb.WriteString(strings.Repeat("-", width) + "\n")
		sb.WriteString(indent)
	}
	sb.WriteString(center(denStr))
	return sb.String()
}

// TypeName returns the name used for fraction values, "fraction".
func (*fraction) TypeName() string {
	const name = "fraction"
	return name
}

// -------- fraction term
// newFractionTerm builds the parse-tree node for the infix fraction operator
// (registered for SymVertBar in init). The node starts with no parent and
// room for two children, and is evaluated by evalFraction.
func newFractionTerm(tk *Token) *term {
	node := &term{
		tk:       *tk,
		children: make([]*term, 0, 2),
		position: posInfix,
		priority: priFraction,
		evalFunc: evalFraction,
	}
	return node
}

// -------- eval func
// evalFraction evaluates the infix fraction operator num|den.
// Both operands must evaluate to int64, and den must be non-zero.
// The result is an int64 when the fraction reduces to a whole number
// (including 0). Otherwise it is a *fraction in lowest terms with a positive
// denominator.
func evalFraction(ctx ExprContext, self *term) (v any, err error) {
	var numValue, denValue any
	if numValue, denValue, err = self.evalInfix(ctx); err != nil {
		return
	}
	num, ok := numValue.(int64)
	if !ok {
		err = fmt.Errorf("numerator must be integer, got %T (%v)", numValue, numValue)
		return
	}
	den, ok := denValue.(int64)
	if !ok {
		err = fmt.Errorf("denominator must be integer, got %T (%v)", denValue, denValue)
		return
	}
	if den == 0 {
		err = errors.New("division by zero")
		return
	}
	// simplifyFraction normalises the sign, reduces by the GCD and collapses
	// whole numbers to int64. It special-cases a zero numerator; calling gcd
	// directly with num == 0 used to panic with integer divide by zero
	// (e.g. "0|5").
	v = simplifyFraction(&fraction{num, den})
	return
}

// gcd returns the greatest common divisor of |a| and |b| using Euclid's
// algorithm. gcd(x, 0) == |x| and gcd(0, 0) == 0. Zero arguments are handled
// without dividing by zero; the previous version panicked whenever either
// argument was 0. Negating math.MinInt64 overflows, so that input is not
// supported.
func gcd(a, b int64) int64 {
	if a < 0 {
		a = -a
	}
	if b < 0 {
		b = -b
	}
	for b != 0 {
		a, b = b, a%b
	}
	return a
}

// lcm returns the least common multiple of a and b. In this file it is
// called with fraction denominators. lcm(0, x) == 0.
// Dividing by the GCD before multiplying keeps the intermediate value no
// larger than the result, so it overflows only when the result itself does.
// The result is not overflow-checked.
func lcm(a, b int64) int64 {
	if a == 0 || b == 0 {
		return 0
	}
	return a / gcd(a, b) * b
}

// sumFract returns f1+f2, reduced by newFraction. The sum is computed over
// the least common multiple of the two denominators. Intermediate products
// are not checked for int64 overflow.
func sumFract(f1, f2 *fraction) *fraction {
	common := lcm(f1.den, f2.den)
	left := f1.num * (common / f1.den)
	right := f2.num * (common / f2.den)
	return newFraction(left+right, common)
}

// mulFract returns f1*f2, reduced by newFraction. The products are not
// checked for int64 overflow.
func mulFract(f1, f2 *fraction) *fraction {
	num := f1.num * f2.num
	den := f1.den * f2.den
	return newFraction(num, den)
}

// anyToFract converts v to a *fraction.
// A *fraction is returned as is (the same pointer, not a copy), and an int64
// n becomes n/1. Any other value, including a nil *fraction, produces an
// error.
func anyToFract(v any) (f *fraction, err error) {
	switch x := v.(type) {
	case *fraction:
		f = x
	case int64:
		f = intToFraction(x)
	}
	if f == nil {
		err = errExpectedGot("fract", typeFraction, v)
	}
	return
}

// anyPairToFract converts both operands with anyToFract and stops at the
// first failure. On error, f2 (and f1, if v1 failed) is nil.
func anyPairToFract(v1, v2 any) (f1, f2 *fraction, err error) {
	f1, err = anyToFract(v1)
	if err == nil {
		f2, err = anyToFract(v2)
	}
	return
}

// sumAnyFract adds two numeric operands (int64 or *fraction).
// The result is an int64 when the sum is a whole number (including zero);
// otherwise it is a *fraction in lowest terms.
func sumAnyFract(af1, af2 any) (sum any, err error) {
	var f1, f2 *fraction
	if f1, f2, err = anyPairToFract(af1, af2); err != nil {
		return
	}
	// simplifyFraction already maps a zero numerator to int64(0). The former
	// special case returned the untyped constant 0, which is an int rather
	// than the int64 used for integers everywhere else.
	sum = simplifyFraction(sumFract(f1, f2))
	return
}

// subAnyFract computes af1 - af2 for numeric operands (int64 or *fraction).
// The result is an int64 when the difference is a whole number (including
// zero); otherwise it is a *fraction in lowest terms. Neither operand is
// modified.
func subAnyFract(af1, af2 any) (diff any, err error) {
	var f1, f2 *fraction
	if f1, f2, err = anyPairToFract(af1, af2); err != nil {
		return
	}
	// Negate into a fresh value. anyToFract returns *fraction operands as is,
	// so flipping f2.num in place would change the caller's right operand.
	negated := &fraction{-f2.num, f2.den}
	diff = simplifyFraction(sumFract(f1, negated))
	return
}

// mulAnyFract multiplies two numeric operands (int64 or *fraction).
// The result is an int64 when the product is a whole number (including
// zero); otherwise it is a *fraction in lowest terms. Products are not checked
// for int64 overflow.
func mulAnyFract(af1, af2 any) (prod any, err error) {
	var f1, f2 *fraction
	if f1, f2, err = anyPairToFract(af1, af2); err != nil {
		return
	}
	// A zero factor reduces to int64(0) inside simplifyFraction. The former
	// special case returned an untyped 0, i.e. an int.
	prod = simplifyFraction(&fraction{f1.num * f2.num, f1.den * f2.den})
	return
}

// divAnyFract computes af1 / af2 for numeric operands (int64 or *fraction).
// It returns a "division by zero" error when af2 is zero. The result is an
// int64 when the quotient is a whole number (including zero); otherwise it is
// a *fraction in lowest terms with a positive denominator.
func divAnyFract(af1, af2 any) (quot any, err error) {
	var f1, f2 *fraction
	if f1, f2, err = anyPairToFract(af1, af2); err != nil {
		return
	}
	if f2.num == 0 {
		err = errors.New("division by zero")
		return
	}
	// Multiply by the reciprocal of f2. f2.num may be negative, and
	// simplifyFraction moves the sign to the numerator. A zero dividend
	// reduces to int64(0).
	quot = simplifyFraction(&fraction{f1.num * f2.den, f1.den * f2.num})
	return
}

// simplifyFraction reduces f to lowest terms with a positive denominator.
// The result is an int64 when the reduced denominator is 1 (including a zero
// numerator); otherwise it is a new *fraction. The argument itself is left
// unchanged. It panics, via simplifyIntegers, if f.den is zero and f.num is
// not.
func simplifyFraction(f *fraction) (v any) {
	num, den := simplifyIntegers(f.num, f.den)
	if den == 1 {
		return num
	}
	return &fraction{num, den}
}

// simplifyIntegers reduces num/den to lowest terms with a positive
// denominator.
// A zero numerator always yields (0, 1), even when den is zero. Otherwise a
// zero denominator is treated as a broken invariant and panics.
func simplifyIntegers(num, den int64) (a, b int64) {
	switch {
	case num == 0:
		return 0, 1
	case den == 0:
		panic("fraction with denominator == 0")
	case den < 0:
		num, den = -num, -den
	}
	divisor := gcd(num, den)
	return num / divisor, den / divisor
}

// intToFraction wraps the integer n as the fraction n/1.
func intToFraction(n int64) *fraction {
	whole := fraction{num: n, den: 1}
	return &whole
}

// isFraction reports whether v holds a *fraction (a typed nil *fraction also
// counts).
func isFraction(v any) (ok bool) {
	switch v.(type) {
	case *fraction:
		ok = true
	}
	return
}

// init registers newFractionTerm as the term constructor for the SymVertBar
// token. That token is presumably '|', matching the "num|den" form that
// ToString produces.
func init() {
	registerTermConstructor(SymVertBar, newFractionTerm)
}