// Copyright (c) 2024 Celestino Amoroso (celestino.amoroso@gmail.com).
// All rights reserved.

// ast.go
package expr

import (
	"strings"
)

//-------- ast

// ast is an expression tree under construction, together with any trees that
// were completed earlier and parked in the forest.
type ast struct {
	// forest holds trees moved off root by ToForest, in insertion order.
	// Eval computes each of them before computing root.
	forest []*term
	// root is the tree currently being built by addTerm/insert; nil when empty.
	root   *term
}

// NewAst returns a pointer to an empty ast: no root and no forest.
func NewAst() *ast {
	return new(ast)
}

// TypeName reports the display name of this value kind, which is always
// "Expression".
func (expr *ast) TypeName() string {
	const typeName = "Expression"
	return typeName
}

// ToForest moves the current root, if any, onto the end of the forest and
// clears root so that a new expression can be built. It does nothing when
// root is nil.
func (expr *ast) ToForest() {
	if expr.root == nil {
		return
	}
	// append allocates the backing array on first use, so a nil forest
	// needs no explicit make.
	expr.forest = append(expr.forest, expr.root)
	expr.root = nil
}

// String renders the current root via its toString method, or "(nil)" when
// there is no root. Trees parked in the forest are not included.
func (expr *ast) String() string {
	if expr.root == nil {
		return "(nil)"
	}
	var out strings.Builder
	expr.root.toString(&out)
	return out.String()
}

// addTokens feeds each token to addToken in order. It stops at the first
// failure and returns that error; tokens after it are not processed.
func (expr *ast) addTokens(tokens ...*Token) (err error) {
	for i := range tokens {
		if err = expr.addToken(tokens[i]); err != nil {
			return err
		}
	}
	return nil
}

// addToken converts tk into a term and adds it to the tree, discarding the
// created term and reporting only the error from addToken2.
func (expr *ast) addToken(tk *Token) error {
	_, err := expr.addToken2(tk)
	return err
}

// addToken2 builds a term from tk with newTerm and adds it to the tree via
// addTerm, returning the new term together with any error from addTerm.
// If newTerm yields nil, it returns a nil term and an "unexpected token"
// error produced by tk.Errorf.
func (expr *ast) addToken2(tk *Token) (t *term, err error) {
	t = newTerm(tk)
	if t == nil {
		return nil, tk.Errorf("unexpected token %q", tk.String())
	}
	return t, expr.addTerm(t)
}

// addTerm attaches node to the expression tree. The first term becomes the
// root; subsequent terms are positioned by insert, which may return a
// different root.
//
// If insert fails, the current root is kept as it was. insert's returned root
// is not meaningful on error (it can be nil), and assigning it
// unconditionally would discard the tree built so far.
func (expr *ast) addTerm(node *term) (err error) {
	if expr.root == nil {
		expr.root = node
		return nil
	}
	var newRoot *term
	if newRoot, err = expr.insert(expr.root, node); err != nil {
		return err
	}
	expr.root = newRoot
	return nil
}

// insert places node into the subtree rooted at tree and returns the root of
// the resulting subtree.
//
//   - If tree.getPriority() < node.getPriority(), tree stays the root and node
//     goes below it. If tree.isComplete(), tree's last child is detached,
//     node is inserted into that child recursively, and the resulting subtree
//     is re-attached to tree. Otherwise node is attached directly as a child
//     of tree.
//   - Otherwise, if node is not a leaf, tree is attached under node and node
//     becomes the new root.
//   - Otherwise an error is returned.
//
// On error the returned root may be nil and should not be used. The subtree
// is left in its original shape: a child detached for a recursive insert is
// re-attached before the error is returned.
func (expr *ast) insert(tree, node *term) (root *term, err error) {
	if tree.getPriority() < node.getPriority() {
		root = tree
		if tree.isComplete() {
			var subRoot *term
			last := tree.removeLastChild()
			if subRoot, err = expr.insert(last, node); err != nil {
				// Re-attach the detached child; setParent is the same call the
				// success path uses to hang a subtree back under tree.
				last.setParent(tree)
				return
			}
			subRoot.setParent(tree)
		} else {
			node.setParent(tree)
		}
	} else if !node.isLeaf() {
		root = node
		tree.setParent(node)
	} else {
		// NOTE(review): this branch is reached when node is a leaf, yet the
		// message speaks of "operators" — confirm the intended wording.
		err = node.Errorf("two adjacent operators: %q and %q", tree, node)
	}
	return
}

// Finish makes sure there is a root to evaluate. When root is nil and the
// forest is non-empty, it pops the last forest tree back into root. Otherwise
// it leaves the ast unchanged.
func (expr *ast) Finish() {
	if expr.root != nil {
		return
	}
	// len of a nil slice is 0, so no separate nil check is required.
	if n := len(expr.forest); n > 0 {
		expr.root = expr.forest[n-1]
		expr.forest = expr.forest[:n-1]
	}
}

// Eval calls Finish, then computes every forest tree in order, followed by
// root. After each successful computation, the value is stored in ctx under
// ControlLastResult. Evaluation stops at the first error, which is returned
// together with whatever the failing compute call produced. When there is no
// root after Finish, nothing is computed and (nil, nil) is returned; an empty
// expression is deliberately not treated as an error.
func (expr *ast) Eval(ctx ExprContext) (result any, err error) {
	expr.Finish()

	if expr.root == nil {
		return
	}

	// Ranging over a nil forest simply performs no iterations.
	for _, tree := range expr.forest {
		if result, err = tree.compute(ctx); err != nil {
			return
		}
		ctx.UnsafeSetVar(ControlLastResult, result)
	}

	if result, err = expr.root.compute(ctx); err == nil {
		ctx.UnsafeSetVar(ControlLastResult, result)
	}
	return
}