// Copyright (c) 2024 Celestino Amoroso (celestino.amoroso@gmail.com).
// All rights reserved.

// operator-sum.go
package expr

import (
	"fmt"
	"slices"
)

//-------- plus term

// newPlusTerm builds the term for the binary "+" operator read from tk.
// The term is an infix operator at sum priority, evaluated by evalPlus,
// and its children slice is pre-sized for the two operands.
func newPlusTerm(tk *Token) (inst *term) {
	inst = &term{
		tk:       *tk,
		position: posInfix,
		priority: priSum,
		evalFunc: evalPlus,
	}
	inst.children = make([]*term, 0, 2)
	return inst
}

// evalPlus evaluates the infix "+" operator on the values of its two operands.
//
// Operand combinations are checked in this order:
//   - a string together with a value accepted by isNumberString (either side):
//     the operands are concatenated using their default %v formatting;
//   - number + number: float sum if either operand is a float, otherwise the
//     int64 sum;
//   - list + list: a new list holding the left items followed by the right
//     items; neither operand list is modified;
//   - a fraction together with a number or another fraction (either side):
//     float sum if either operand is a float, otherwise the result of
//     sumAnyFract;
//   - dict + dict: a clone of the left dict merged with the right dict; the
//     left operand itself is not modified.
//
// Any other combination yields the error built by errIncompatibleTypes.
// Errors from evaluating the operands are returned unchanged.
func evalPlus(ctx ExprContext, self *term) (v any, err error) {
	var leftValue, rightValue any

	if leftValue, rightValue, err = self.evalInfix(ctx); err != nil {
		return
	}

	switch {
	case (IsString(leftValue) && isNumberString(rightValue)) || (IsString(rightValue) && isNumberString(leftValue)):
		v = fmt.Sprintf("%v%v", leftValue, rightValue)

	case IsNumber(leftValue) && IsNumber(rightValue):
		if IsFloat(leftValue) || IsFloat(rightValue) {
			v = numAsFloat(leftValue) + numAsFloat(rightValue)
		} else {
			leftInt, _ := leftValue.(int64)
			rightInt, _ := rightValue.(int64)
			v = leftInt + rightInt
		}

	case IsList(leftValue) && IsList(rightValue):
		leftList, _ := leftValue.(*ListType)
		rightList, _ := rightValue.(*ListType)
		// Build a fresh list so that neither operand is aliased or mutated.
		sumList := make(ListType, 0, len(*leftList)+len(*rightList))
		sumList = append(sumList, *leftList...)
		sumList = append(sumList, *rightList...)
		v = &sumList

	// isNumOrFract (not IsNumber) on the other operand, so that
	// fraction + fraction is accepted just like fraction - fraction is in
	// evalMinus. IsNumber excludes fractions, otherwise the previous case
	// would already have caught fraction + int.
	case (isFraction(leftValue) && isNumOrFract(rightValue)) || (isFraction(rightValue) && isNumOrFract(leftValue)):
		if IsFloat(leftValue) || IsFloat(rightValue) {
			v = numAsFloat(leftValue) + numAsFloat(rightValue)
		} else {
			v, err = sumAnyFract(leftValue, rightValue)
		}

	case IsDict(leftValue) && IsDict(rightValue):
		leftDict, _ := leftValue.(*DictType)
		rightDict, _ := rightValue.(*DictType)
		c := leftDict.clone()
		c.merge(rightDict)
		v = c

	default:
		err = self.errIncompatibleTypes(leftValue, rightValue)
	}
	return
}

//-------- minus term

// newMinusTerm builds the term for the binary "-" operator read from tk.
// The term is an infix operator at sum priority, evaluated by evalMinus,
// and its children slice is pre-sized for the two operands.
func newMinusTerm(tk *Token) (inst *term) {
	inst = &term{
		tk:       *tk,
		position: posInfix,
		priority: priSum,
		evalFunc: evalMinus,
	}
	inst.children = make([]*term, 0, 2)
	return inst
}

// evalMinus evaluates the infix "-" operator on the values of its two operands.
//
// Operand combinations:
//   - numbers and/or fractions: float difference if either operand is a
//     float; otherwise, if either operand is a fraction, the result of
//     subAnyFract; otherwise the int64 difference;
//   - list - list: a new list holding, in their original order, the items of
//     the left list that do not occur in the right list (compared with ==).
//     Neither operand list is modified.
//
// Any other combination yields the error built by errIncompatibleTypes.
// Errors from evaluating the operands are returned unchanged.
func evalMinus(ctx ExprContext, self *term) (v any, err error) {
	var leftValue, rightValue any

	if leftValue, rightValue, err = self.evalInfix(ctx); err != nil {
		return
	}

	switch {
	case isNumOrFract(leftValue) && isNumOrFract(rightValue):
		if IsFloat(leftValue) || IsFloat(rightValue) {
			v = numAsFloat(leftValue) - numAsFloat(rightValue)
		} else if isFraction(leftValue) || isFraction(rightValue) {
			v, err = subAnyFract(leftValue, rightValue)
		} else {
			leftInt, _ := leftValue.(int64)
			rightInt, _ := rightValue.(int64)
			v = leftInt - rightInt
		}

	case IsList(leftValue) && IsList(rightValue):
		leftList, _ := leftValue.(*ListType)
		rightList, _ := rightValue.(*ListType)
		// The difference can never be longer than the left list. Using
		// len(left)-len(right) as capacity would be negative, making make()
		// panic, whenever the right list is the longer one.
		diffList := make(ListType, 0, len(*leftList))
		for _, item := range *leftList {
			if !slices.Contains(*rightList, item) {
				diffList = append(diffList, item)
			}
		}
		v = &diffList

	default:
		err = self.errIncompatibleTypes(leftValue, rightValue)
	}
	return
}

// init registers newPlusTerm and newMinusTerm as the term constructors for the
// SymPlus and SymMinus symbols (presumably consulted by the parser when it
// meets "+" and "-"; confirm in registerTermConstructor's caller).
func init() {
	registerTermConstructor(SymPlus, newPlusTerm)
	registerTermConstructor(SymMinus, newMinusTerm)
}