// Copyright (c) 2024 Celestino Amoroso (celestino.amoroso@gmail.com).
// All rights reserved.

// operator-sum.go
package expr

import (
	"fmt"
	"slices"
)

//-------- plus term

// newPlusTerm builds the infix term for the binary "+" operator from token tk.
// The term has room for two children, sum priority and evalPlus as evaluator.
func newPlusTerm(tk *Token) (inst *term) {
	inst = new(term)
	inst.tk = *tk
	inst.position = posInfix
	inst.priority = priSum
	inst.children = make([]*term, 0, 2)
	inst.evalFunc = evalPlus
	return inst
}

// evalPlus evaluates the binary "+" operator.
//
// Both operands are evaluated first via self.evalInfix; an error there is
// returned as-is. The operands are then combined, checking in this order:
//   - one operand is a string and the other satisfies isNumberString
//     (presumably a number or a string): the result is the concatenation of
//     their %v representations;
//   - both are numbers: a floating-point sum (via numAsFloat) if either one is
//     a float, otherwise an int64 sum;
//   - at least one is a list: a new list holding the left items followed by
//     the right items, where an operand that is not a []any counts as a
//     one-item list;
//   - anything else: the error from self.errIncompatibleTypes.
func evalPlus(ctx ExprContext, self *term) (v any, err error) {
	var leftValue, rightValue any

	if leftValue, rightValue, err = self.evalInfix(ctx); err != nil {
		return
	}

	if (isString(leftValue) && isNumberString(rightValue)) || (isString(rightValue) && isNumberString(leftValue)) {
		v = fmt.Sprintf("%v%v", leftValue, rightValue)
	} else if isNumber(leftValue) && isNumber(rightValue) {
		if isFloat(leftValue) || isFloat(rightValue) {
			v = numAsFloat(leftValue) + numAsFloat(rightValue)
		} else {
			// NOTE(review): assumes every non-float number is an int64; any other
			// dynamic type would silently be read as 0 here — confirm with isNumber.
			leftInt, _ := leftValue.(int64)
			rightInt, _ := rightValue.(int64)
			v = leftInt + rightInt
		}
	} else if isList(leftValue) || isList(rightValue) {
		leftList, isLeftList := leftValue.([]any)
		if !isLeftList {
			leftList = []any{leftValue}
		}
		rightList, isRightList := rightValue.([]any)
		if !isRightList {
			rightList = []any{rightValue}
		}
		// Build into a fresh slice so the result never shares a backing array
		// with either operand.
		sumList := make([]any, 0, len(leftList)+len(rightList))
		sumList = append(sumList, leftList...)
		v = append(sumList, rightList...)
	} else {
		err = self.errIncompatibleTypes(leftValue, rightValue)
	}
	return
}

//-------- minus term

// newMinusTerm builds the infix term for the binary "-" operator from token tk.
// The term has room for two children, sum priority and evalMinus as evaluator.
func newMinusTerm(tk *Token) (inst *term) {
	inst = new(term)
	inst.tk = *tk
	inst.position = posInfix
	inst.priority = priSum
	inst.children = make([]*term, 0, 2)
	inst.evalFunc = evalMinus
	return inst
}

// evalMinus evaluates the binary "-" operator.
//
// Both operands are evaluated first via self.evalInfix; an error there is
// returned as-is. The operands are then combined:
//   - both are numbers: a floating-point difference (via numAsFloat) if either
//     one is a float, otherwise an int64 difference;
//   - both are lists: a new list with the items of the left list that do not
//     occur in the right list, keeping the left list's order and duplicates;
//   - anything else: the error from self.errIncompatibleTypes.
func evalMinus(ctx ExprContext, self *term) (v any, err error) {
	var leftValue, rightValue any

	if leftValue, rightValue, err = self.evalInfix(ctx); err != nil {
		return
	}

	if isNumber(leftValue) && isNumber(rightValue) {
		if isFloat(leftValue) || isFloat(rightValue) {
			v = numAsFloat(leftValue) - numAsFloat(rightValue)
		} else {
			// NOTE(review): assumes every non-float number is an int64; any other
			// dynamic type would silently be read as 0 here — confirm with isNumber.
			leftInt, _ := leftValue.(int64)
			rightInt, _ := rightValue.(int64)
			v = leftInt - rightInt
		}
	} else if isList(leftValue) && isList(rightValue) {
		leftList, _ := leftValue.([]any)
		rightList, _ := rightValue.([]any)
		// The difference is never longer than the left list. Sizing it as
		// len(left)-len(right) produced a negative capacity, and a runtime panic
		// in make, whenever the right list was the longer one (e.g. [1] - [1, 2]).
		diffList := make([]any, 0, len(leftList))
		for _, item := range leftList {
			// NOTE(review): slices.Contains compares with ==, which panics at run
			// time if an item's dynamic type is not comparable (e.g. a nested list).
			if !slices.Contains(rightList, item) {
				diffList = append(diffList, item)
			}
		}
		v = diffList
	} else {
		err = self.errIncompatibleTypes(leftValue, rightValue)
	}
	return
}

// init registers the term constructors for the additive infix operators:
// SymPlus is built by newPlusTerm and SymMinus by newMinusTerm.
func init() {
	registerTermConstructor(SymPlus, newPlusTerm)
	registerTermConstructor(SymMinus, newMinusTerm)
}