// Copyright (c) 2024 Celestino Amoroso (celestino.amoroso@gmail.com).
// All rights reserved.

// funcs-math.go
package expr

import (
	"fmt"
	"io"
)

// checkNumberParamExpected validates a single argument of funcName.
//
// A value is accepted when IsNumber or isFraction reports true for it.
// Otherwise it returns an error that names the function and gives the
// 1-based parameter position (paramPos+1), the 1-based position inside the
// nested sequence (subPos+1), the nesting level, and the offending dynamic type.
func checkNumberParamExpected(funcName string, paramValue any, paramPos, level, subPos int) (err error) {
	if IsNumber(paramValue) || isFraction(paramValue) {
		return nil
	}
	return fmt.Errorf("%s(): param nr %d (%d in %d) has wrong type %T, number expected",
		funcName, paramPos+1, subPos+1, level, paramValue)
}

// doAdd sums every value produced by it.
//
// *ListType elements are wrapped in a list iterator. Iterator elements are
// summed recursively, and their partial sum is added to the running total.
// Every other element must pass checkNumberParamExpected. The result type is
// promoted as values are seen: int64 at first, *fraction once a fraction is
// seen, and float64 once a float is seen (float always takes precedence over
// fraction). An empty iterator yields int64(0).
//
// Parameters:
//   - ctx:   forwarded unchanged to recursive calls.
//   - name:  function name used in error messages.
//   - it:    source of the addends; io.EOF from it.Next marks normal termination.
//   - count: used as the parameter position in error messages.
//   - level: nesting depth of the caller; incremented on entry, so the top-level
//     call passes -1 to place its own arguments at level 0.
//
// On failure the result is nil and err is the iterator, validation or
// cleanup error.
func doAdd(ctx ExprContext, name string, it Iterator, count, level int) (result any, err error) {
	var sumAsFloat, sumAsFract bool
	var floatSum float64
	var intSum int64
	var fractSum *fraction
	var v any

	level++

	for v, err = it.Next(); err == nil; v, err = it.Next() {
		if list, ok := v.(*ListType); ok {
			v = NewListIterator(list, nil)
		}
		if subIter, ok := v.(Iterator); ok {
			v, err = doAdd(ctx, name, subIter, count, level)
			// The clean operation belongs to the sub-iterator. v now holds the
			// partial sum, so it must not be the value tested here. Run the
			// cleanup even when the nested sum failed, but let that failure
			// take precedence over a cleanup error.
			if extIter, ok := subIter.(ExtIterator); ok && extIter.HasOperation(cleanName) {
				if _, cleanErr := extIter.CallOperation(cleanName, nil); cleanErr != nil && err == nil {
					return nil, cleanErr
				}
			}
			if err != nil {
				break
			}
		} else if err = checkNumberParamExpected(name, v, count, level, it.Index()); err != nil {
			break
		}
		// NOTE(review): in a nested call, count starts at the enclosing list's
		// position and keeps growing for each element, so errors for later list
		// elements report a shifted "param nr". Confirm whether this is intended.
		count++

		// Promote the accumulator the first time a wider numeric kind shows up.
		if !sumAsFloat {
			switch {
			case IsFloat(v):
				sumAsFloat = true
				if sumAsFract {
					floatSum = fractSum.toFloat()
				} else {
					floatSum = float64(intSum)
				}
			case !sumAsFract && isFraction(v):
				fractSum = newFraction(intSum, 1)
				sumAsFract = true
			}
		}

		switch {
		case sumAsFloat:
			floatSum += numAsFloat(v)
		case sumAsFract:
			item, ok := v.(*fraction)
			if !ok {
				// NOTE(review): any non-int64 integer accepted by IsNumber
				// would be counted as 0 here. Confirm IsNumber's domain.
				iv, _ := v.(int64)
				item = newFraction(iv, 1)
			}
			fractSum = sumFract(fractSum, item)
		default:
			// NOTE(review): same silent-zero caveat as above for non-int64 values.
			iv, _ := v.(int64)
			intSum += iv
		}
	}

	if err != nil && err != io.EOF {
		return nil, err
	}
	switch {
	case sumAsFloat:
		return floatSum, nil
	case sumAsFract:
		return fractSum, nil
	default:
		return intSum, nil
	}
}

// addFunc is the implementation registered as add(). It delegates to doAdd,
// feeding it the call arguments through an array iterator.
func addFunc(ctx ExprContext, name string, args []any) (result any, err error) {
	// doAdd increments the level on entry, so -1 places the direct
	// arguments at level 0.
	return doAdd(ctx, name, NewArrayIterator(args), 0, -1)
}

// doMul multiplies every value produced by it.
//
// *ListType elements are wrapped in a list iterator. Iterator elements are
// multiplied recursively, and their partial product is folded into the running
// product. Every other element must pass checkNumberParamExpected. The result
// type is promoted as values are seen: int64 at first, *fraction once a
// fraction is seen, and float64 once a float is seen (float always takes
// precedence over fraction). An empty iterator yields int64(1).
//
// Parameters:
//   - ctx:   forwarded unchanged to recursive calls.
//   - name:  function name used in error messages.
//   - it:    source of the factors; io.EOF from it.Next marks normal termination.
//   - count: used as the parameter position in error messages.
//   - level: nesting depth of the caller; incremented on entry, so the top-level
//     call passes -1 to place its own arguments at level 0.
//
// On failure the result is nil and err is the iterator, validation or
// cleanup error.
func doMul(ctx ExprContext, name string, it Iterator, count, level int) (result any, err error) {
	var mulAsFloat, mulAsFract bool
	var fractProd *fraction
	var v any
	floatProd := 1.0
	intProd := int64(1)

	level++

	for v, err = it.Next(); err == nil; v, err = it.Next() {
		if list, ok := v.(*ListType); ok {
			v = NewListIterator(list, nil)
		}
		if subIter, ok := v.(Iterator); ok {
			v, err = doMul(ctx, name, subIter, count, level)
			// The clean operation belongs to the sub-iterator. v now holds the
			// partial product, so it must not be the value tested here. Run the
			// cleanup even when the nested product failed, but let that failure
			// take precedence over a cleanup error.
			if extIter, ok := subIter.(ExtIterator); ok && extIter.HasOperation(cleanName) {
				if _, cleanErr := extIter.CallOperation(cleanName, nil); cleanErr != nil && err == nil {
					return nil, cleanErr
				}
			}
			if err != nil {
				break
			}
		} else if err = checkNumberParamExpected(name, v, count, level, it.Index()); err != nil {
			break
		}
		// NOTE(review): in a nested call, count starts at the enclosing list's
		// position and keeps growing for each element, so errors for later list
		// elements report a shifted "param nr". Confirm whether this is intended.
		count++

		// Promote the accumulator the first time a wider numeric kind shows up.
		if !mulAsFloat {
			switch {
			case IsFloat(v):
				mulAsFloat = true
				if mulAsFract {
					floatProd = fractProd.toFloat()
				} else {
					floatProd = float64(intProd)
				}
			case !mulAsFract && isFraction(v):
				fractProd = newFraction(intProd, 1)
				mulAsFract = true
			}
		}

		switch {
		case mulAsFloat:
			floatProd *= numAsFloat(v)
		case mulAsFract:
			item, ok := v.(*fraction)
			if !ok {
				// NOTE(review): any non-int64 integer accepted by IsNumber
				// would become 0 here. Confirm IsNumber's domain.
				iv, _ := v.(int64)
				item = newFraction(iv, 1)
			}
			fractProd = mulFract(fractProd, item)
		default:
			// NOTE(review): same silent-zero caveat as above for non-int64 values.
			iv, _ := v.(int64)
			intProd *= iv
		}
	}

	if err != nil && err != io.EOF {
		return nil, err
	}
	switch {
	case mulAsFloat:
		return floatProd, nil
	case mulAsFract:
		return fractProd, nil
	default:
		return intProd, nil
	}
}

// mulFunc is the implementation registered as mul(). It delegates to doMul,
// feeding it the call arguments through an array iterator.
func mulFunc(ctx ExprContext, name string, args []any) (result any, err error) {
	// doMul increments the level on entry, so -1 places the direct
	// arguments at level 0.
	return doMul(ctx, name, NewArrayIterator(args), 0, -1)
}

// ImportMathFuncs registers the add() and mul() functions in ctx.
//
// Both are registered with the bounds 0 and -1. These presumably mean
// "at least zero arguments, no upper limit"; confirm against RegisterFunc.
func ImportMathFuncs(ctx ExprContext) {
	builtins := []struct {
		name    string
		functor *simpleFunctor
	}{
		{"add", &simpleFunctor{f: addFunc}},
		{"mul", &simpleFunctor{f: mulFunc}},
	}
	for _, b := range builtins {
		ctx.RegisterFunc(b.name, b.functor, 0, -1)
	}
}

// init registers ImportMathFuncs via registerImport under the name
// "math.arith", together with a short human-readable description.
func init() {
	const description = "Functions add() and mul()"
	registerImport("math.arith", ImportMathFuncs, description)
}