// Copyright (c) 2024 Celestino Amoroso (celestino.amoroso@gmail.com).
// All rights reserved.

// parser.go
package expr

import (
	"errors"

	"golang.org/x/exp/constraints"
)

//-------- parser

// parserContext is a bit set of flags that tune how the parser treats the
// tokens of the (sub)expression it is currently parsing.
type parserContext uint16

// Parser context flags. Bit 0 is not used: the first flag is bit 1.
const (
	// parserNoFlags is the empty flag set.
	parserNoFlags = 0

	// allowMultiExpr lets ";" split the input into several expressions.
	allowMultiExpr parserContext = 1 << 1
	// allowVarRef permits "@"-prefixed variable references.
	allowVarRef parserContext = 1 << 2
	// selectorContext is meant to mark parsing inside a selector ("?") expression.
	selectorContext parserContext = 1 << 3
	// listContext: square brackets used for a list.
	listContext parserContext = 1 << 4
	// indexContext: square brackets used for an index.
	indexContext parserContext = 1 << 5
	// allowIndex: square brackets may be interpreted as an index.
	allowIndex parserContext = 1 << 6

	// squareContext covers square brackets used either for a list or an index.
	squareContext = listContext | indexContext
)

// hasFlag reports whether at least one bit of singleFlag is set in set.
func hasFlag[T constraints.Unsigned](set T, singleFlag T) bool {
	masked := set & singleFlag
	return masked != 0
}

// addFlags returns a copy of set with every bit of flags turned on.
func addFlags[T constraints.Unsigned](set T, flags T) T {
	result := set
	result |= flags
	return result
}

// addFlagsCond returns set with every bit of flags turned on when cond is
// true, and set unchanged otherwise.
func addFlagsCond[T constraints.Unsigned](set T, flags T, cond bool) (newSet T) {
	if !cond {
		return set
	}
	return set | flags
}

// remFlags returns a copy of set with every bit of flags cleared.
func remFlags[T constraints.Unsigned](set T, flags T) T {
	// &^ is Go's bit-clear (AND NOT) operator.
	return set &^ flags
}

// parser turns the tokens produced by a scanner into an ast. It holds no
// fields: all the parsing state lives in the scanner and in local variables.
type parser struct {
}

// NewParser returns a new parser, ready to use.
func NewParser() (p *parser) {
	return new(parser)
}

// Next returns the next token of scanner, skipping any comment tokens.
func (parser *parser) Next(scanner *scanner) (tk *Token) {
	for {
		tk = scanner.Next()
		if !tk.IsSymbol(SymComment) {
			return tk
		}
	}
}

// parseFuncCall parses the comma separated arguments of a function call up
// to the closing ")". tk is the SymFuncCall token that starts the call; the
// scanner is presumably positioned right after the opening "(" (TODO
// confirm against the scanner). It returns a function-call term built from
// tk and the parsed arguments.
//
// An empty argument list "()" is accepted. An empty argument anywhere else
// (leading, doubled or trailing comma) is reported as an error.
func (parser *parser) parseFuncCall(scanner *scanner, ctx parserContext, tk *Token) (tree *term, err error) {
	args := make([]*term, 0, 10)
	itemExpected := false
	lastSym := SymUnknown
	for lastSym != SymClosedRound && lastSym != SymEos {
		var subTree *ast
		if subTree, err = parser.parseItem(scanner, ctx, SymComma, SymClosedRound); err != nil {
			break
		}
		prev := scanner.Previous()
		if subTree.root != nil {
			args = append(args, subTree.root)
		} else if itemExpected || prev.Sym == SymComma {
			// An empty item is only legal as the whole list "()": one that
			// follows a comma ("f(x,)") or is terminated by a comma ("f(,x)")
			// is an error.
			err = prev.ErrorExpectedGot("function-param-value")
			break
		}

		lastSym = prev.Sym
		itemExpected = lastSym == SymComma
	}
	if err == nil {
		if lastSym != SymClosedRound {
			err = errors.New("unterminated arguments list")
		} else {
			tree = newFuncCallTerm(tk, args)
		}
	}
	return
}

// parseFuncDef parses a function definition such as "add = func(x, y) {x+y}".
// It is called right after the SymFuncDef token has been consumed (presumably
// "func(" - TODO confirm against the scanner).
//
// Parameters are comma separated identifiers terminated by ")". A parameter
// may carry a default value ("x=1"); once a parameter has a default, every
// following one must have one too. Duplicated parameter names are rejected,
// as is anything other than "=", "," or ")" after a parameter name.
//
// The body must follow in braces and is parsed with allowMultiExpr and
// allowVarRef. The returned function-definition term carries the body ast
// in a SymExpression value token.
func (parser *parser) parseFuncDef(scanner *scanner) (tree *term, err error) {
	var body *ast
	args := make([]*term, 0)
	lastSym := SymUnknown
	defaultParamsStarted := false
	itemExpected := false
	tk := scanner.Previous()
	for lastSym != SymClosedRound && lastSym != SymEos {
		tk = parser.Next(scanner)
		if tk.IsSymbol(SymIdentifier) {
			param := newTerm(tk)
			if pos := paramAlreadyDefined(args, param); pos > 0 {
				err = tk.Errorf("parameter %q at position %d already defined at position %d", param.source(), len(args)+1, pos)
				break
			}
			args = append(args, param)
			tk = parser.Next(scanner)
			if tk.Sym == SymEqual {
				var paramExpr *ast
				defaultParamsStarted = true
				if paramExpr, err = parser.parseItem(scanner, parserNoFlags, SymComma, SymClosedRound); err != nil {
					break
				}
				if paramExpr.root == nil {
					// "x=" followed directly by "," or ")": no default value given.
					err = scanner.Previous().ErrorExpectedGot("default-param-value")
					break
				}
				param.forceChild(paramExpr.root)
			} else if defaultParamsStarted {
				err = tk.Errorf("can't mix default and non-default parameters")
				break
			} else if tk.Sym != SymComma && tk.Sym != SymClosedRound && tk.Sym != SymEos {
				// A parameter name must be followed by "=", "," or ")";
				// any other token would otherwise be silently dropped.
				err = tk.ErrorExpectedGot(", or )")
				break
			}
		} else if itemExpected || (tk.Sym != SymClosedRound && tk.Sym != SymEos) {
			// Where no parameter name is found, only ")" (empty list) or the
			// end of input (reported below) are acceptable, and only if no
			// parameter is required after a comma.
			prev := scanner.Previous()
			err = prev.ErrorExpectedGot("function-param-spec")
			break
		}
		lastSym = scanner.Previous().Sym
		itemExpected = lastSym == SymComma
	}

	if err == nil && lastSym != SymClosedRound {
		err = tk.ErrorExpectedGot(")")
	}
	if err == nil {
		tk = parser.Next(scanner)
		if tk.IsSymbol(SymOpenBrace) {
			body, err = parser.parseGeneral(scanner, allowMultiExpr|allowVarRef, SymClosedBrace)
		} else {
			err = tk.ErrorExpectedGot("{")
		}
	}
	if err == nil {
		if scanner.Previous().Sym != SymClosedBrace {
			err = scanner.Previous().ErrorExpectedGot("}")
		} else {
			tk = scanner.makeValueToken(SymExpression, "", body)
			tree = newFuncDefTerm(tk, args)
		}
	}
	return
}

// paramAlreadyDefined searches args for a term with the same source text as
// param. It returns the 1-based position of the last such term, or 0 when
// there is none.
func paramAlreadyDefined(args []*term, param *term) (position int) {
	for idx := len(args) - 1; idx >= 0; idx-- {
		if args[idx].source() == param.source() {
			return idx + 1
		}
	}
	return 0
}

// parseList parses the comma separated items enclosed in square brackets,
// after the opening "[" has been consumed, up to the closing "]". It returns
// a list term positioned where the list starts.
//
// When ctx has allowIndex set, the brackets may denote an index:
//   - a ":" item is turned into a range;
//   - a range whose item started with ":" (presumably "[:n]") gets an
//     explicit 0 as its first bound.
//
// Without allowIndex, range items are rejected.
//
// An empty list "[]" is accepted. An empty item anywhere else (leading,
// doubled or trailing comma) is reported as an error.
func (parser *parser) parseList(scanner *scanner, ctx parserContext) (listTerm *term, err error) {
	r, c := scanner.lastPos()
	args := make([]*term, 0)
	lastSym := SymUnknown
	itemExpected := false
	// Items are parsed without allowIndex, so that brackets nested in an
	// item are not forced to be an index by this context.
	itemCtx := remFlags(ctx, allowIndex)
	for lastSym != SymClosedSquare && lastSym != SymEos {
		zeroRequired := scanner.current.Sym == SymColon
		var itemTree *ast
		if itemTree, err = parser.parseItem(scanner, itemCtx, SymComma, SymClosedSquare); err != nil {
			break
		}
		prev := scanner.Previous()
		root := itemTree.root
		if root == nil {
			// An empty item after a comma ("[1,]") or terminated by a comma
			// ("[,1]") is an error; only the whole list may be empty.
			if itemExpected || prev.Sym == SymComma {
				err = prev.ErrorExpectedGot("list-item")
				break
			}
		} else {
			if hasFlag(ctx, allowIndex) && root.symbol() == SymColon {
				changeColonToRange(root)
			}
			if !hasFlag(ctx, allowIndex) && root.symbol() == SymRange {
				err = errRangeUnexpectedExpression(root)
				break
			}
			args = append(args, root)
			if hasFlag(ctx, allowIndex) && root.symbol() == SymRange && zeroRequired {
				// The range has no start bound: its single child is the end
				// bound. Any other child count would make children[0] below
				// either invalid or out of range, so it is rejected.
				if len(root.children) != 1 {
					err = errRangeInvalidSpecification(root)
					break
				}
				// Duplicate the end bound to make room, then replace the
				// start bound with an explicit 0.
				root.children = append(root.children, root.children[0])
				zeroTk := NewValueToken(root.tk.row, root.tk.col, SymInteger, "0", int64(0))
				zeroTerm := newTerm(zeroTk)
				zeroTerm.setParent(root)
				root.children[0] = zeroTerm
			}
		}
		lastSym = prev.Sym
		// NOTE(review): ranges are currently accepted at every index position.
		// A remFlags(ctx, allowIndex) call used to sit here after a comma,
		// but its result was discarded, so it never had any effect. Assign it
		// to ctx if ranges after the first index position must be rejected.
		itemExpected = lastSym == SymComma
	}
	if err == nil {
		if lastSym != SymClosedSquare {
			err = scanner.Previous().ErrorExpectedGot("]")
		} else {
			listTerm = newListTerm(r, c, args)
		}
	}
	return
}

// parseIterDef parses the comma separated arguments of an iterator
// definition up to the closing ")". It is called right after the
// SymDollarRound token (presumably "$(") has been consumed; that token
// becomes the token of the returned iterator term.
//
// An empty argument list is accepted. An empty argument anywhere else
// (leading, doubled or trailing comma) is reported as an error.
func (parser *parser) parseIterDef(scanner *scanner, ctx parserContext) (subtree *term, err error) {
	tk := scanner.Previous()
	args := make([]*term, 0)
	lastSym := SymUnknown
	itemExpected := false
	for lastSym != SymClosedRound && lastSym != SymEos {
		var subTree *ast
		if subTree, err = parser.parseItem(scanner, ctx, SymComma, SymClosedRound); err != nil {
			break
		}
		prev := scanner.Previous()
		if subTree.root != nil {
			args = append(args, subTree.root)
		} else if itemExpected || prev.Sym == SymComma {
			// Reject "$(x,)" as well as "$(,x)".
			err = prev.ErrorExpectedGot("iterator-param")
			break
		}
		lastSym = prev.Sym
		itemExpected = lastSym == SymComma
	}
	if err == nil {
		if lastSym != SymClosedRound {
			err = scanner.Previous().ErrorExpectedGot(")")
		} else {
			subtree = newIteratorTerm(tk, args)
		}
	}
	return
}

// parseDictKey reads a dictionary key and the ":" that must follow it. Only
// integer and string literals are valid keys. When the closing "}" or the
// end of input is found instead of a key, it returns a nil key and a nil
// error. A scanner error token is returned as an error.
func (parser *parser) parseDictKey(scanner *scanner) (key any, err error) {
	keyTk := parser.Next(scanner)
	switch keyTk.Sym {
	case SymError:
		err = keyTk.Error()
	case SymClosedBrace, SymEos:
		// no more keys: leave key and err nil
	case SymInteger, SymString:
		if sepTk := parser.Next(scanner); sepTk.Sym == SymColon {
			key = keyTk.Value
		} else {
			err = sepTk.ErrorExpectedGot(":")
		}
	default:
		err = keyTk.ErrorExpectedGot("dictionary-key or }")
	}
	return
}

// parseDictionary parses the comma separated "key: value" pairs of a
// dictionary literal, after the opening "{" has been consumed, up to the
// closing "}". A key after a trailing comma is required, so "{1: 2,}" is an
// error. When a key appears more than once, its last value wins.
func (parser *parser) parseDictionary(scanner *scanner, ctx parserContext) (subtree *term, err error) {
	entries := make(map[any]*term, 0)
	lastSym := SymUnknown
	keyRequired := false
	for lastSym != SymClosedBrace && lastSym != SymEos {
		var key any
		if key, err = parser.parseDictKey(scanner); err != nil {
			break
		}
		if key == nil {
			// "}" or end of input found where a key could start.
			endTk := scanner.Previous()
			lastSym = endTk.Sym
			if keyRequired {
				err = endTk.ErrorExpectedGot("dictionary-key")
			}
			break
		}

		var valueTree *ast
		if valueTree, err = parser.parseItem(scanner, ctx, SymComma, SymClosedBrace); err != nil {
			break
		}
		if valueTree.root == nil {
			err = scanner.Previous().ErrorExpectedGot("dictionary-value")
			break
		}
		entries[key] = valueTree.root

		lastSym = scanner.Previous().Sym
		keyRequired = lastSym == SymComma
	}

	if err != nil {
		return
	}
	if lastSym != SymClosedBrace {
		err = scanner.Previous().ErrorExpectedGot("}")
		return
	}
	subtree = newDictTerm(entries)
	return
}

// parseSelectorCase parses one case of a selector: an optional "[...]" filter
// list followed by a "{...}" body. A default case (defaultCase true) must not
// have a filter list and gets a nil one; a regular case without a list gets
// an empty one. The case term is positioned at the token that opens the body.
func (parser *parser) parseSelectorCase(scanner *scanner, ctx parserContext, defaultCase bool) (caseTerm *term, err error) {
	var filterList *term
	var caseExpr *ast
	ctx = remFlags(ctx, allowIndex)
	tk := parser.Next(scanner)
	row, col := tk.row, tk.col

	switch {
	case tk.Sym == SymOpenSquare && defaultCase:
		err = tk.Errorf("case list in default clause")
		return
	case tk.Sym == SymOpenSquare:
		if filterList, err = parser.parseList(scanner, ctx); err != nil {
			return
		}
		tk = parser.Next(scanner)
		row, col = tk.row, tk.col
	case !defaultCase:
		filterList = newListTerm(row, col, make([]*term, 0))
	}

	if tk.Sym != SymOpenBrace {
		err = tk.ErrorExpectedGot("{")
		return
	}
	if caseExpr, err = parser.parseGeneral(scanner, ctx|allowMultiExpr, SymClosedBrace); err != nil {
		return
	}
	caseTerm = newSelectorCaseTerm(row, col, filterList, caseExpr)
	return
}

// addSelectorCase appends caseTerm to the case list of selectorTerm, which is
// kept as the selector's second child. While the selector has fewer than two
// children, a new single-case list term is appended to its children.
// caseTerm's parent is set to selectorTerm in both cases.
func addSelectorCase(selectorTerm, caseTerm *term) {
	if len(selectorTerm.children) >= 2 {
		listTerm := selectorTerm.children[1]
		// A value that is not a []*term yields nil here, so the list then
		// restarts from caseTerm alone.
		cases, _ := listTerm.value().([]*term)
		listTerm.tk.Value = append(cases, caseTerm)
	} else {
		selectorTerm.children = append(selectorTerm.children, newListTermA(caseTerm))
	}
	caseTerm.parent = selectorTerm
}

// parseSelector handles the "?" selector operator. It adds a selector token
// to tree, then parses the first (non-default) case, which may use variable
// references. The selector term is returned even when that case fails to
// parse.
func (parser *parser) parseSelector(scanner *scanner, tree *ast, ctx parserContext) (selectorTerm *term, err error) {
	selectorTk := scanner.makeToken(SymSelector, '?')
	if selectorTerm, err = tree.addToken(selectorTk); err != nil {
		return
	}

	caseCtx := remFlags(ctx, allowIndex) | allowVarRef
	var firstCase *term
	if firstCase, err = parser.parseSelectorCase(scanner, caseCtx, false); err != nil {
		return
	}
	addSelectorCase(selectorTerm, firstCase)
	return
}

// parseItem parses a single item (argument, list element, dictionary value,
// ...) up to one of termSymbols. Variable references are always allowed
// inside an item.
func (parser *parser) parseItem(scanner *scanner, ctx parserContext, termSymbols ...Symbol) (tree *ast, err error) {
	itemCtx := addFlags(ctx, allowVarRef)
	tree, err = parser.parseGeneral(scanner, itemCtx, termSymbols...)
	return
}

// Parse parses the whole input of scanner, which may hold several
// ";"-separated expressions, until one of termSymbols or the end of input
// (SymEos, always accepted) is found.
//
// termSymbols is never modified: SymEos is appended to a private copy.
func (parser *parser) Parse(scanner *scanner, termSymbols ...Symbol) (tree *ast, err error) {
	// Capping the slice at its length forces append to allocate a new
	// backing array instead of writing into spare capacity that may belong
	// to the caller (e.g. when called as Parse(s, syms...)).
	terms := append(termSymbols[:len(termSymbols):len(termSymbols)], SymEos)
	return parser.parseGeneral(scanner, allowMultiExpr, terms...)
}

// couldBeACollection reports whether t is a list, string, dictionary,
// expression or variable term. Such a term may be followed by an index in
// square brackets. A nil term is never a collection.
func couldBeACollection(t *term) bool {
	if t == nil {
		return false
	}
	switch t.symbol() {
	case SymList, SymString, SymDict, SymExpression, SymVariable:
		return true
	default:
		return false
	}
}

// listSubTree adds a parsed bracket list to tree and returns the term that
// represents it. When allowIndeces is true, a SymIndex term (sharing
// listTerm's position and source) is added first, followed by listTerm, and
// the index term is returned. Otherwise listTerm alone is added and returned.
func listSubTree(tree *ast, listTerm *term, allowIndeces bool) (root *term, err error) {
	if !allowIndeces {
		return listTerm, tree.addTerm(listTerm)
	}

	indexTk := NewToken(listTerm.tk.row, listTerm.tk.col, SymIndex, listTerm.source())
	root = newTerm(indexTk)
	if err = tree.addTerm(root); err != nil {
		return root, err
	}
	return root, tree.addTerm(listTerm)
}

// parseGeneral is the core of the parser. It reads tokens until one of
// termSymbols is found and assembles them into an ast. Nested constructs
// (parenthesized expressions, function calls and definitions, lists and
// indexes, dictionaries, iterators and selectors) are delegated to the
// dedicated parse* methods, which call back into parseGeneral.
//
// ctx controls what is allowed:
//   - allowMultiExpr lets ";" close the current expression and start a new
//     one (the tree becomes a forest);
//   - allowVarRef permits "@"-prefixed identifiers.
//
// When the loop ends on a token that is not one of termSymbols, an
// "expected one of ..." error is returned. Otherwise the error carried by
// the terminating token is returned (tk.Error(); presumably nil for ordinary
// tokens - TODO confirm).
func (parser *parser) parseGeneral(scanner *scanner, ctx parserContext, termSymbols ...Symbol) (tree *ast, err error) {
	var selectorTerm, currentTerm *term
	var tk *Token

	tree = NewAst()
	firstToken := true
	for tk = parser.Next(scanner); err == nil && tk != nil && !tk.IsTerm(termSymbols); tk = parser.Next(scanner) {
		if tk.Sym == SymSemiColon {
			if !hasFlag(ctx, allowMultiExpr) {
				err = tk.Errorf(`unexpected token %q, expected ",", "]", or ")"`, tk.source)
				break
			}
			// ";" closes the current expression and starts a new one.
			tree.ToForest()
			firstToken = true
			currentTerm = nil
			selectorTerm = nil
			continue
		}

		if firstToken {
			// At the start of an operand "-" and "+" are sign operators, while
			// a postfix "!" has nothing to apply to.
			if tk.Sym == SymMinus {
				tk.Sym = SymChangeSign
			} else if tk.Sym == SymPlus {
				tk.Sym = SymUnchangeSign
			} else if tk.IsSymbol(SymExclamation) {
				err = tk.Errorf("postfix opertor %q requires an operand on its left", tk)
				break
			}
			firstToken = false
		}

		switch tk.Sym {
		case SymOpenRound:
			var subTree *ast
			if subTree, err = parser.parseGeneral(scanner, ctx, SymClosedRound); err == nil {
				if subTree.root == nil {
					// "()" encloses no expression: report it instead of
					// dereferencing a nil root.
					err = scanner.Previous().ErrorExpectedGot("expression")
				} else {
					subTree.root.priority = priValue
					err = tree.addTerm(newExprTerm(subTree.root))
					currentTerm = subTree.root
				}
			}
		case SymFuncCall:
			var funcCallTerm *term
			if funcCallTerm, err = parser.parseFuncCall(scanner, ctx, tk); err == nil {
				err = tree.addTerm(funcCallTerm)
				currentTerm = funcCallTerm
			}
		case SymOpenSquare:
			// allowIndex is added when the preceding term could be a
			// collection; the brackets are then added to the tree as an index.
			var listTerm *term
			newCtx := addFlagsCond(addFlags(ctx, squareContext), allowIndex, couldBeACollection(currentTerm))
			if listTerm, err = parser.parseList(scanner, newCtx); err == nil {
				currentTerm, err = listSubTree(tree, listTerm, hasFlag(newCtx, allowIndex))
			}
		case SymOpenBrace:
			// After a ":" a "{" would be a selector-case body, which is only
			// valid inside a selector; otherwise it opens a dictionary.
			if currentTerm != nil && currentTerm.symbol() == SymColon {
				err = currentTerm.Errorf(`selector-case outside of a selector context`)
			} else {
				var mapTerm *term
				if mapTerm, err = parser.parseDictionary(scanner, ctx); err == nil {
					err = tree.addTerm(mapTerm)
					currentTerm = mapTerm
				}
			}
		case SymEqual, SymPlusEqual, SymMinusEqual, SymStarEqual, SymSlashEqual, SymPercEqual:
			// The right-hand side of an assignment starts a new operand.
			currentTerm, err = tree.addToken(tk)
			firstToken = true
		case SymFuncDef:
			var funcDefTerm *term
			if funcDefTerm, err = parser.parseFuncDef(scanner); err == nil {
				err = tree.addTerm(funcDefTerm)
				currentTerm = funcDefTerm
			}
		case SymDollarRound:
			var iterDefTerm *term
			if iterDefTerm, err = parser.parseIterDef(scanner, ctx); err == nil {
				err = tree.addTerm(iterDefTerm)
				currentTerm = iterDefTerm
			}
		case SymIdentifier:
			if tk.source[0] == '@' && !hasFlag(ctx, allowVarRef) {
				err = tk.Errorf("variable references are not allowed in top level expressions: %q", tk.source)
			} else {
				currentTerm, err = tree.addToken(tk)
			}
		case SymQuestion:
			if selectorTerm, err = parser.parseSelector(scanner, tree, ctx); err == nil {
				currentTerm = selectorTerm
				ctx = addFlags(ctx, selectorContext)
			}
		case SymColon, SymDoubleColon:
			var caseTerm *term
			if selectorTerm != nil {
				// Inside a selector ":" introduces another case and "::" the
				// default case, which also closes the selector.
				if caseTerm, err = parser.parseSelectorCase(scanner, ctx, tk.Sym == SymDoubleColon); err == nil {
					addSelectorCase(selectorTerm, caseTerm)
					currentTerm = caseTerm
					if tk.Sym == SymDoubleColon {
						selectorTerm = nil
					}
				}
			} else {
				currentTerm, err = tree.addToken(tk)
			}
			if tk.IsOneOfA(SymColon, SymRange) {
				// Colon outside a selector term acts like a separator
				firstToken = true
			}
		default:
			currentTerm, err = tree.addToken(tk)
		}

		// Leaving the selector: the current term neither is the selector nor
		// hangs directly from it.
		if currentTerm != nil && currentTerm.tk.Sym != SymSelector && currentTerm.parent != nil && currentTerm.parent.tk.Sym != SymSelector {
			selectorTerm = nil
			ctx = remFlags(ctx, selectorContext)
		}
	}

	if err == nil {
		if !tk.IsOneOf(termSymbols) {
			var symDesc string
			if tk.IsSymbol(SymError) {
				symDesc = tk.ErrorText()
			} else {
				symDesc = SymToString(tk.Sym)
			}
			err = tk.ErrorExpectedGotStringWithPrefix("expected one of", SymListToString(termSymbols, true), symDesc)
		} else {
			err = tk.Error()
		}
	}
	return
}