// Copyright (c) 2024 Celestino Amoroso (celestino.amoroso@gmail.com).
// All rights reserved.

// operator-rel.go
package expr

import "reflect"

//-------- equal term

// newEqualTerm builds the infix term for the equality operator (registered
// under SymDoubleEqual). It has room for two operands, relational priority,
// and evaluates through evalEqual.
func newEqualTerm(tk *Token) (inst *term) {
	inst = &term{
		tk:       *tk,
		position: posInfix,
		priority: priRelational,
		children: make([]*term, 0, 2),
		evalFunc: evalEqual,
	}
	return
}

// deepFuncTemplate is the signature of the optional comparator that equals
// delegates to when both operands are lists. It reports whether left and
// right are equal, together with any error encountered while comparing them.
type deepFuncTemplate func(left, right any) (bool, error)

// equals reports whether a and b are equal.
//
// When both operands satisfy isNumOrFract they are compared by value:
//   - two integers are compared as int64;
//   - any other pair of plain numbers is compared through numAsFloat;
//   - otherwise (at least one side is not a plain number) cmpAnyFract decides,
//     and its error, if any, is returned.
//
// For non-numeric operands, a non-nil deepCmp is used when both sides are
// lists; every remaining case falls back to reflect.DeepEqual.
func equals(a, b any, deepCmp deepFuncTemplate) (eq bool, err error) {
	// Non-numeric operands: optional list comparator, then structural equality.
	if !isNumOrFract(a) || !isNumOrFract(b) {
		if deepCmp != nil && IsList(a) && IsList(b) {
			eq, err = deepCmp(a, b)
			return
		}
		eq = reflect.DeepEqual(a, b)
		return
	}

	// At least one side is not a plain number: defer to the fraction comparer.
	if !IsNumber(a) || !IsNumber(b) {
		var cmp int
		if cmp, err = cmpAnyFract(a, b); err != nil {
			return
		}
		eq = cmp == 0
		return
	}

	if IsInteger(a) && IsInteger(b) {
		left, _ := a.(int64)
		right, _ := b.(int64)
		eq = left == right
		return
	}

	eq = numAsFloat(a) == numAsFloat(b)
	return
}

// evalEqual evaluates both operands of an equality term and compares them
// with equals (without a list-specific comparator). The result is a bool
// boxed in v; an operand evaluation error is returned as-is with a nil v.
func evalEqual(ctx ExprContext, self *term) (v any, err error) {
	var lhs, rhs any
	if lhs, rhs, err = self.evalInfix(ctx); err == nil {
		v, err = equals(lhs, rhs, nil)
	}
	return
}

//-------- not equal term

// newNotEqualTerm builds the infix term for the inequality operator
// (registered under SymNotEqual). It has room for two operands, relational
// priority, and evaluates through evalNotEqual.
func newNotEqualTerm(tk *Token) (inst *term) {
	inst = &term{
		tk:       *tk,
		position: posInfix,
		priority: priRelational,
		children: make([]*term, 0, 2),
		evalFunc: evalNotEqual,
	}
	return
}

// evalNotEqual evaluates an inequality term as the logical negation of
// evalEqual. If evalEqual fails, its value and error are returned unchanged.
func evalNotEqual(ctx ExprContext, self *term) (v any, err error) {
	if v, err = evalEqual(ctx, self); err != nil {
		return
	}
	// On success evalEqual yields the bool produced by equals, so the
	// secondary result of the conversion is intentionally discarded.
	equal, _ := toBool(v)
	v = !equal
	return
}

//-------- less term

// newLessTerm builds the infix term for the less-than operator (registered
// under SymLess). It has room for two operands, relational priority, and
// evaluates through evalLess.
func newLessTerm(tk *Token) (inst *term) {
	inst = &term{
		tk:       *tk,
		position: posInfix,
		priority: priRelational,
		children: make([]*term, 0, 2),
		evalFunc: evalLess,
	}
	return
}

// lessThan reports whether a < b.
//
// Supported operand pairs:
//   - numbers/fractions: two integers compare as int64, other plain numbers
//     through numAsFloat, anything involving a non-plain number through
//     cmpAnyFract (whose error is propagated);
//   - strings: ordinary Go string ordering;
//   - lists: an inclusion test (see below).
//
// Any other combination yields self.errIncompatibleTypes(a, b).
func lessThan(self *term, a, b any) (isLess bool, err error) {
	switch {
	case isNumOrFract(a) && isNumOrFract(b):
		if !IsNumber(a) || !IsNumber(b) {
			var cmp int
			if cmp, err = cmpAnyFract(a, b); err != nil {
				return
			}
			isLess = cmp < 0
			return
		}
		if IsInteger(a) && IsInteger(b) {
			left, _ := a.(int64)
			right, _ := b.(int64)
			isLess = left < right
			return
		}
		isLess = numAsFloat(a) < numAsFloat(b)

	case IsString(a) && IsString(b):
		left, _ := a.(string)
		right, _ := b.(string)
		isLess = left < right

	case IsList(a) && IsList(b):
		// Inclusion test: a < b when a is shorter than b and b.contains(a)
		// holds, i.e. presumably a proper sub-list of b.
		left, _ := a.(*ListType)
		right, _ := b.(*ListType)
		isLess = len(*left) < len(*right) && right.contains(left)

	default:
		err = self.errIncompatibleTypes(a, b)
	}
	return
}

// evalLess evaluates both operands of a less-than term and compares them
// with lessThan. An operand evaluation error is returned as-is with a nil v.
func evalLess(ctx ExprContext, self *term) (v any, err error) {
	var lhs, rhs any
	if lhs, rhs, err = self.evalInfix(ctx); err == nil {
		v, err = lessThan(self, lhs, rhs)
	}
	return
}

//-------- less or equal term

// newLessEqualTerm builds the infix term for the less-or-equal operator
// (registered under SymLessOrEqual). It has room for two operands,
// relational priority, and evaluates through evalLessEqual.
func newLessEqualTerm(tk *Token) (inst *term) {
	inst = &term{
		tk:       *tk,
		position: posInfix,
		priority: priRelational,
		children: make([]*term, 0, 2),
		evalFunc: evalLessEqual,
	}
	return
}

// lessThanOrEqual reports whether a <= b.
//
// It first tries lessThan; an error from lessThan is returned unchanged. If
// a < b does not hold, two lists are checked with sameContent and any other
// pair with equals (without a list comparator).
func lessThanOrEqual(self *term, a, b any) (isLessEq bool, err error) {
	isLessEq, err = lessThan(self, a, b)
	switch {
	case err != nil, isLessEq:
		// Either the operands are incomparable or a < b already holds.
	case IsList(a) && IsList(b):
		isLessEq, err = sameContent(a, b)
	default:
		isLessEq, err = equals(a, b, nil)
	}
	return
}

// evalLessEqual evaluates both operands of a less-or-equal term and compares
// them with lessThanOrEqual. An operand evaluation error is returned as-is
// with a nil v.
func evalLessEqual(ctx ExprContext, self *term) (v any, err error) {
	var lhs, rhs any
	if lhs, rhs, err = self.evalInfix(ctx); err == nil {
		v, err = lessThanOrEqual(self, lhs, rhs)
	}
	return
}

//-------- greater term

// newGreaterTerm builds the infix term for the greater-than operator
// (registered under SymGreater). It has room for two operands, relational
// priority, and evaluates through evalGreater.
func newGreaterTerm(tk *Token) (inst *term) {
	inst = &term{
		tk:       *tk,
		position: posInfix,
		priority: priRelational,
		children: make([]*term, 0, 2),
		evalFunc: evalGreater,
	}
	return
}

// evalGreater evaluates a greater-than term by rewriting left > right as
// right < left and delegating to lessThan. An operand evaluation error is
// returned as-is with a nil v.
//
// NOTE(review): because the operands are swapped, any error lessThan builds
// (e.g. via errIncompatibleTypes(a, b)) sees them in right/left order —
// confirm the resulting message reads correctly.
func evalGreater(ctx ExprContext, self *term) (v any, err error) {
	var lhs, rhs any
	if lhs, rhs, err = self.evalInfix(ctx); err == nil {
		v, err = lessThan(self, rhs, lhs)
	}
	return
}

//-------- greater or equal term

// newGreaterEqualTerm builds the infix term for the greater-or-equal operator
// (registered under SymGreaterOrEqual). It has room for two operands,
// relational priority, and evaluates through evalGreaterEqual.
func newGreaterEqualTerm(tk *Token) (inst *term) {
	inst = &term{
		tk:       *tk,
		position: posInfix,
		priority: priRelational,
		children: make([]*term, 0, 2),
		evalFunc: evalGreaterEqual,
	}
	return
}

// evalGreaterEqual evaluates a greater-or-equal term by rewriting
// left >= right as right <= left and delegating to lessThanOrEqual. An
// operand evaluation error is returned as-is with a nil v.
//
// NOTE(review): because the operands are swapped, any error produced further
// down (e.g. errIncompatibleTypes in lessThan) sees them in right/left
// order — confirm the resulting message reads correctly.
func evalGreaterEqual(ctx ExprContext, self *term) (v any, err error) {
	var lhs, rhs any
	if lhs, rhs, err = self.evalInfix(ctx); err == nil {
		v, err = lessThanOrEqual(self, rhs, lhs)
	}
	return
}

// init registers the constructor of each relational operator term with its
// symbol, so the parser can build ==, !=, <, <=, > and >= terms.
func init() {
	registerTermConstructor(SymDoubleEqual, newEqualTerm)
	registerTermConstructor(SymNotEqual, newNotEqualTerm)
	registerTermConstructor(SymLess, newLessTerm)
	registerTermConstructor(SymLessOrEqual, newLessEqualTerm)
	registerTermConstructor(SymGreater, newGreaterTerm)
	registerTermConstructor(SymGreaterOrEqual, newGreaterEqualTerm)
}