// Copyright (c) 2024 Celestino Amoroso (celestino.amoroso@gmail.com).
// All rights reserved.

// operator-assign.go
package expr

//-------- assign term

// newAssignTerm builds the infix term for the plain assignment operator "=".
// The term is created with room for its two operands (target and value) and
// is evaluated by evalAssign.
func newAssignTerm(tk *Token) (inst *term) {
	inst = &term{
		tk:       *tk,
		position: posInfix,
		priority: priAssign,
		evalFunc: evalAssign,
	}
	inst.children = make([]*term, 0, 2)
	return inst
}

// assignCollectionItem stores value into an item of a list or dictionary.
//
// collectionTerm must evaluate to a *ListType or *DictType. keyListTerm must
// evaluate to a *ListType holding exactly one non-nil element. That element is
// used as the list index, which must be an int64, or as the dictionary key.
// The actual store is delegated to the collection's setItem method.
func assignCollectionItem(ctx ExprContext, collectionTerm, keyListTerm *term, value any) (err error) {
	var target, spec any

	if target, err = collectionTerm.compute(ctx); err != nil {
		return err
	}
	if spec, err = keyListTerm.compute(ctx); err != nil {
		return err
	}

	// The index/key arrives wrapped in a one-element list.
	specList, isSeq := spec.(*ListType)
	if !isSeq || len(*specList) != 1 {
		return keyListTerm.Errorf("index/key specification expected, got %v [%s]", spec, TypeName(spec))
	}

	key := (*specList)[0]
	if key == nil {
		return keyListTerm.Errorf("index/key is nil")
	}

	if list, isList := target.(*ListType); isList {
		index, isInt := key.(int64)
		if !isInt {
			return keyListTerm.Errorf("integer expected, got %v [%s]", key, TypeName(key))
		}
		return list.setItem(index, value)
	}
	if dict, isDict := target.(*DictType); isDict {
		return dict.setItem(key, value)
	}
	return collectionTerm.Errorf("collection expected")
}

// assignValue stores v into the target described by leftTerm. When leftTerm is
// an index expression, v goes into the indexed collection item. Otherwise it
// is stored (via UnsafeSetVar) in the variable named by leftTerm's source text.
func assignValue(ctx ExprContext, leftTerm *term, v any) (err error) {
	if leftTerm.symbol() != SymIndex {
		ctx.UnsafeSetVar(leftTerm.source(), v)
		return nil
	}
	collectionTerm, keyTerm := leftTerm.children[0], leftTerm.children[1]
	return assignCollectionItem(ctx, collectionTerm, keyTerm, v)
}

// evalAssign evaluates "target = expr".
//
// The target must be a variable or an indexed collection item. If the right
// operand evaluates to a Functor and the target is a plain variable, the
// functor is registered as a function under the variable's name instead of
// being stored as a value. On success the right operand's value is returned;
// on any error v is nil.
func evalAssign(ctx ExprContext, opTerm *term) (v any, err error) {
	if err = opTerm.checkOperands(); err != nil {
		return
	}

	leftTerm := opTerm.children[0]
	leftSym := leftTerm.symbol()
	if leftSym != SymVariable && leftSym != SymIndex {
		err = leftTerm.tk.Errorf("left operand of %q must be a variable or a collection's item", opTerm.tk.source)
		return
	}

	rightChild := opTerm.children[1]
	if v, err = rightChild.compute(ctx); err != nil {
		return nil, err
	}

	functor, isFunctor := v.(Functor)
	if !isFunctor || leftSym != SymVariable {
		// Ordinary values, and functors assigned to collection items, are stored as-is.
		err = assignValue(ctx, leftTerm, v)
	} else if info := functor.GetFunc(); info != nil {
		ctx.RegisterFunc(leftTerm.source(), info.Functor(), info.ReturnType(), info.Params())
	} else if funcDef, isExprFunc := functor.(*exprFunctor); isExprFunc {
		paramSpecs := ForAll(funcDef.params, func(p ExprFuncParam) ExprFuncParam { return p })
		ctx.RegisterFunc(leftTerm.source(), functor, TypeAny, paramSpecs)
	} else {
		err = opTerm.Errorf("unknown function %s()", rightChild.source())
	}

	if err != nil {
		v = nil
	}
	return
}

//-------- op-assign term (+=, -=, *=, /=, %=)

// newOpAssignTerm builds the infix term for a compound assignment operator
// such as "+=". It has the same shape and priority as a plain assignment term
// but is evaluated by evalOpAssign.
func newOpAssignTerm(tk *Token) (inst *term) {
	inst = &term{
		tk:       *tk,
		position: posInfix,
		priority: priAssign,
		evalFunc: evalOpAssign,
	}
	inst.children = make([]*term, 0, 2)
	return inst
}

// getCollectionItemValue returns the current value of a list item or a
// dictionary entry. Its operand handling matches assignCollectionItem:
// collectionTerm must evaluate to a *ListType or *DictType, and keyListTerm
// must evaluate to a *ListType holding exactly one non-nil index/key.
//
// For lists the index must be an int64 in the range [0, len). An out-of-range
// index is reported as an error instead of causing a runtime panic. For
// dictionaries a missing key yields the map's zero value.
func getCollectionItemValue(ctx ExprContext, collectionTerm, keyListTerm *term) (value any, err error) {
	var collectionValue, keyListValue, keyValue any
	var keyList *ListType
	var ok bool

	if collectionValue, err = collectionTerm.compute(ctx); err != nil {
		return
	}

	if keyListValue, err = keyListTerm.compute(ctx); err != nil {
		return
	} else if keyList, ok = keyListValue.(*ListType); !ok || len(*keyList) != 1 {
		err = keyListTerm.Errorf("index/key specification expected, got %v [%s]", keyListValue, TypeName(keyListValue))
		return
	}
	if keyValue = (*keyList)[0]; keyValue == nil {
		err = keyListTerm.Errorf("index/key is nil")
		return
	}

	switch collection := collectionValue.(type) {
	case *ListType:
		index, isInt := keyValue.(int64)
		if !isInt {
			err = keyListTerm.Errorf("integer expected, got %v [%s]", keyValue, TypeName(keyValue))
		} else if index < 0 || index >= int64(len(*collection)) {
			// Guard the raw slice access: without this check an expression
			// such as `list[99] += 1` panics at runtime instead of failing.
			err = keyListTerm.Errorf("index %d out of bounds (list size %d)", index, len(*collection))
		} else {
			value = (*collection)[index]
		}
	case *DictType:
		value = (*collection)[keyValue]
	default:
		err = collectionTerm.Errorf("collection expected")
	}
	return
}

// getAssignValue reads the current value of an assignment target. For an
// index expression that is the indexed collection item; otherwise it is the
// variable named by leftTerm's source text.
//
// The second result of GetVar (presumably an "exists" flag) is deliberately
// discarded, so an undefined variable reads as whatever GetVar returns for it.
// NOTE(review): assumed to be nil — confirm that callers reject that value.
func getAssignValue(ctx ExprContext, leftTerm *term) (value any, err error) {
	if leftTerm.symbol() != SymIndex {
		value, _ = ctx.GetVar(leftTerm.source())
		return value, nil
	}
	return getCollectionItemValue(ctx, leftTerm.children[0], leftTerm.children[1])
}

// evalOpAssign evaluates compound assignments such as "target += expr".
//
// The steps are:
//  1. Evaluate the right operand.
//  2. Read the current value of the target.
//  3. Combine the two according to the operator (+=, -=, *=, /=, %=).
//  4. Store the result back into the target.
//
// The target must be a variable or an indexed collection item. On success the
// new value is returned; on any error v is nil, matching evalAssign.
//
// NOTE(review): for an indexed target, the collection and key expressions are
// computed twice (once by getAssignValue to read, once by assignValue to
// write), so any side effects in those expressions would run twice.
func evalOpAssign(ctx ExprContext, opTerm *term) (v any, err error) {
	var rightValue, leftValue any

	if err = opTerm.checkOperands(); err != nil {
		return nil, err
	}

	leftTerm := opTerm.children[0]
	leftSym := leftTerm.symbol()
	if leftSym != SymVariable && leftSym != SymIndex {
		err = leftTerm.tk.Errorf("left operand of %q must be a variable or a collection's item", opTerm.tk.source)
		return nil, err
	}

	rightChild := opTerm.children[1]
	if rightValue, err = rightChild.compute(ctx); err != nil {
		return nil, err
	}
	if leftValue, err = getAssignValue(ctx, leftTerm); err != nil {
		return nil, err
	}

	switch opTerm.symbol() {
	case SymPlusEqual:
		v, err = sumValues(opTerm, leftValue, rightValue)
	case SymMinusEqual:
		v, err = diffValues(opTerm, leftValue, rightValue)
	case SymStarEqual:
		v, err = mulValues(opTerm, leftValue, rightValue)
	case SymSlashEqual:
		v, err = divValues(opTerm, leftValue, rightValue)
	case SymPercEqual:
		v, err = remainderValues(opTerm, leftValue, rightValue)
	default:
		err = opTerm.Errorf("unsupported assign operator %q", opTerm.source())
	}
	if err == nil {
		err = assignValue(ctx, leftTerm, v)
	}

	// Do not leak a partially computed value when the operation or the store
	// failed.
	if err != nil {
		v = nil
	}
	return
}

// init registers the term constructors for the assignment operators. "="
// builds a plain assignment term (evalAssign). The compound operators build
// op-assign terms evaluated by evalOpAssign.
func init() {
	registerTermConstructor(SymEqual, newAssignTerm)
	registerTermConstructor(SymPlusEqual, newOpAssignTerm)
	registerTermConstructor(SymMinusEqual, newOpAssignTerm)
	registerTermConstructor(SymStarEqual, newOpAssignTerm)
	registerTermConstructor(SymSlashEqual, newOpAssignTerm)
	// evalOpAssign implements "%=", so it needs a constructor as well.
	// Previously only the four operators above were registered here.
	// NOTE(review): confirm no other file registers SymPercEqual.
	registerTermConstructor(SymPercEqual, newOpAssignTerm)
}