// Copyright (c) 2024 Celestino Amoroso (celestino.amoroso@gmail.com).
// All rights reserved.

// expr project scanner.go
package expr

import (
	"bufio"
	"errors"
	"fmt"
	"io"
	"strconv"
	"strings"
)

// scanner converts a byte stream into Tokens while keeping one token of
// lookahead and a running row/column position.
type scanner struct {
	// current is the lookahead token: it has already been fetched and is the
	// one handed out by the next call to Next.
	current      *Token
	// prev is the token most recently handed out by Next (nil before the
	// first call).
	prev         *Token
	// stream is the buffered input the bytes are read from.
	stream       *bufio.Reader
	// row starts at 1 and is incremented by readChar on every '\n'.
	row          int
	// column starts at 1, is incremented by readChar for every other byte and
	// is reset to 0 right after a '\n'.
	column       int
	// translations optionally replaces a scanned symbol with another one; it
	// is consulted by translate and may be nil.
	translations map[Symbol]Symbol
}

// NewScanner builds a scanner that reads from s through a buffered reader.
//
// The position counters start at row 1, column 1. The first token is fetched
// eagerly, so it is already available as lookahead when the function returns.
// translations may be nil; when set it is used to substitute symbols as
// tokens are created.
func NewScanner(s io.Reader, translations map[Symbol]Symbol) (inst *scanner) {
	inst = new(scanner)
	inst.stream = bufio.NewReader(s)
	inst.row, inst.column = 1, 1
	inst.translations = translations

	// Prime the one-token lookahead.
	inst.current = inst.fetchNextToken()
	return
}

// DefaultTranslations returns a freshly allocated symbol substitution table.
//
// It maps SymDoubleAmpersand and SymKwAnd to SymAnd, SymDoubleVertBar and
// SymKwOr to SymOr, SymTilde and SymKwNot to SymNot, and SymLessGreater to
// SymNotEqual.
func DefaultTranslations() map[Symbol]Symbol {
	table := make(map[Symbol]Symbol, 7)

	// Conjunction spellings.
	table[SymDoubleAmpersand] = SymAnd
	table[SymKwAnd] = SymAnd

	// Disjunction spellings.
	table[SymDoubleVertBar] = SymOr
	table[SymKwOr] = SymOr

	// Negation spellings.
	table[SymTilde] = SymNot
	table[SymKwNot] = SymNot

	// Alternative inequality operator.
	table[SymLessGreater] = SymNotEqual

	return table
}

// func (self *scanner) Current() *Token {
// 	return self.current
// }

// readChar consumes one byte from the stream and updates the position.
//
// On success a '\n' advances row and resets column to 0, while any other byte
// advances column by one. On failure the position is left untouched and the
// error from the underlying reader is returned.
func (self *scanner) readChar() (ch byte, err error) {
	ch, err = self.stream.ReadByte()
	if err != nil {
		return ch, err
	}
	switch ch {
	case '\n':
		self.row++
		self.column = 0
	default:
		self.column++
	}
	return ch, nil
}

// unreadChar pushes the most recently read byte back onto the stream and
// rewinds the position counters.
//
// The column is decremented; when it reaches 0 the row is decremented too and
// the column is set to 1. If that takes the row down to 0 an error is
// returned. A failure of the underlying UnreadByte is returned as is, with
// the position unchanged.
//
// NOTE(review): the row is decremented whenever the column reaches 0, without
// looking at which byte was unread, whereas readChar resets column to 0 only
// after a '\n'. Confirm the two are meant to mirror each other.
func (self *scanner) unreadChar() (err error) {
	if err = self.stream.UnreadByte(); err != nil {
		return err
	}

	self.column--
	if self.column != 0 {
		return nil
	}

	self.row--
	if self.row == 0 {
		return errors.New("unread beyond the stream boundary")
	}
	self.column = 1
	return nil
}

// Previous returns the token handed out by the latest call to Next, or nil
// when Next has not been called yet.
func (self *scanner) Previous() *Token {
	last := self.prev
	return last
}

// Next hands out the lookahead token and fetches a new one to replace it.
// The returned token also becomes the value reported by Previous.
func (self *scanner) Next() (tk *Token) {
	tk, self.prev = self.current, self.current
	// Refill the lookahead slot only after the old token has been saved.
	self.current = self.fetchNextToken()
	return
}

// fetchNextToken skips leading blanks and scans a single token.
//
// An error token is returned when skipBlanks fails (this includes the end of
// the input, which skipBlanks reports as an error), when the input ends right
// after a backslash, or when the byte read does not start any known token.
//
// Multi-byte operators are recognised by peeking at the following byte(s) and
// consuming them through moveOn only when they belong to the operator.
//
// NOTE(review): the escape flag set by a backslash is cleared only by a second
// backslash, a double quote or the end of the input. Any other byte after a
// backslash keeps the loop running and the token scanned meanwhile is
// overwritten by the next one — confirm whether that is intended.
func (self *scanner) fetchNextToken() (tk *Token) {
	var ch byte
	if err := self.skipBlanks(); err != nil {
		return self.makeErrorToken(err)
	}

	escape := false
	for {
		// A failed read leaves ch == 0, which is handled by the `case 0` arm.
		ch, _ = self.readChar()
		switch ch {
		case '+':
			if next, _ := self.peek(); next == '+' {
				tk = self.moveOn(SymDoublePlus, ch, next)
			} else if next == '=' {
				tk = self.moveOn(SymPlusEqual, ch, next)
			} else {
				tk = self.makeToken(SymPlus, ch)
			}
		case '-':
			if next, _ := self.peek(); next == '-' {
				tk = self.moveOn(SymDoubleMinus, ch, next)
			} else if next == '=' {
				tk = self.moveOn(SymMinusEqual, ch, next)
			} else {
				tk = self.makeToken(SymMinus, ch)
			}
		case '*':
			if next, _ := self.peek(); next == '*' {
				tk = self.moveOn(SymDoubleStar, ch, next)
			} else {
				tk = self.makeToken(SymStar, ch)
			}
		case '/':
			if next, _ := self.peek(); next == '*' {
				self.readChar() // consume the '*' that opens the block comment
				tk = self.fetchBlockComment()
			} else if next == '/' {
				self.readChar() // consume the second '/'
				tk = self.fetchOnLineComment()
			} else {
				tk = self.makeToken(SymSlash, ch)
			}
		case '\\':
			if escape {
				tk = self.makeToken(SymBackSlash, ch)
				escape = false
			} else {
				escape = true
			}
		case '|':
			if next, _ := self.peek(); next == '|' {
				tk = self.moveOn(SymDoubleVertBar, ch, next)
			} else {
				tk = self.makeToken(SymVertBar, ch)
			}
		case ',':
			tk = self.makeToken(SymComma, ch)
		case '^':
			tk = self.makeToken(SymCaret, ch)
		case ':':
			if next, _ := self.peek(); next == ':' {
				tk = self.moveOn(SymDoubleColon, ch, next)
			} else {
				tk = self.makeToken(SymColon, ch)
			}
		case ';':
			tk = self.makeToken(SymSemiColon, ch)
		case '.':
			if next, _ := self.peek(); next == '/' {
				tk = self.moveOn(SymDotSlash, ch, next)
			} else if next == '.' {
				// Telling ".." from "..." needs two bytes of lookahead: a
				// second one-byte peek would just return the same dot again.
				// When fewer than two bytes are left Peek reports an error and
				// the operator is the two-dot one.
				if ahead, perr := self.stream.Peek(2); perr == nil && ahead[1] == '.' {
					third := ahead[1]
					tk = self.moveOn(SymTripleDot, ch, next, third)
				} else {
					tk = self.moveOn(SymDoubleDot, ch, next)
				}
			} else {
				tk = self.makeToken(SymDot, ch)
			}
		case '\'':
			tk = self.makeToken(SymQuote, ch)
		case '"':
			if escape {
				tk = self.makeToken(SymDoubleQuote, ch)
				escape = false
			} else {
				tk = self.fetchString()
			}
		case '`':
			tk = self.makeToken(SymBackTick, ch)
		case '!':
			if next, _ := self.peek(); next == '=' {
				tk = self.moveOn(SymNotEqual, ch, next)
			} else {
				tk = self.makeToken(SymExclamation, ch)
			}
		case '?':
			if next, _ := self.peek(); next == '?' {
				tk = self.moveOn(SymDoubleQuestion, ch, next)
			} else if next == '=' {
				tk = self.moveOn(SymQuestionEqual, ch, next)
			} else {
				tk = self.makeToken(SymQuestion, ch)
			}
		case '&':
			if next, _ := self.peek(); next == '&' {
				tk = self.moveOn(SymDoubleAmpersand, ch, next)
			} else {
				tk = self.makeToken(SymAmpersand, ch)
			}
		case '%':
			tk = self.makeToken(SymPercent, ch)
		case '#':
			tk = self.makeToken(SymHash, ch)
		case '@':
			if next, _ := self.peek(); (next >= 'a' && next <= 'z') || (next >= 'A' && next <= 'Z') {
				self.readChar() // consume the letter that was just peeked
				if tk = self.fetchIdentifier(next); tk.Sym == SymIdentifier {
					tk.source = "@" + tk.source
				} else {
					tk = self.makeErrorToken(fmt.Errorf("invalid variable reference %q", tk.source))
				}
			} else if next == '@' {
				tk = self.moveOn(SymDoubleAt, ch, next)
			} else {
				tk = self.makeToken(SymAt, ch)
			}
		case '_':
			tk = self.makeToken(SymUndescore, ch)
		case '=':
			if next, _ := self.peek(); next == '=' {
				tk = self.moveOn(SymDoubleEqual, ch, next)
			} else {
				tk = self.makeToken(SymEqual, ch)
			}
		case '<':
			if next, _ := self.peek(); next == '=' {
				tk = self.moveOn(SymLessOrEqual, ch, next)
			} else if next == '<' {
				tk = self.moveOn(SymAppend, ch, next)
			} else if next == '>' {
				tk = self.moveOn(SymLessGreater, ch, next)
			} else {
				tk = self.makeToken(SymLess, ch)
			}
		case '>':
			if next, _ := self.peek(); next == '=' {
				tk = self.moveOn(SymGreaterOrEqual, ch, next)
			} else if next == '>' {
				tk = self.moveOn(SymInsert, ch, next)
			} else {
				tk = self.makeToken(SymGreater, ch)
			}
		case '$':
			if next, _ := self.peek(); next == '(' {
				tk = self.moveOn(SymDollarRound, ch, next)
				tk.source += ")"
			} else if next == '$' {
				tk = self.moveOn(SymDoubleDollar, ch, next)
			} else {
				tk = self.makeToken(SymDollar, ch)
			}
		case '(':
			if next, _ := self.peek(); next == ')' {
				tk = self.moveOn(SymOpenClosedRound, ch, next)
			} else {
				tk = self.makeToken(SymOpenRound, ch)
			}
		case ')':
			tk = self.makeToken(SymClosedRound, ch)
		case '[':
			tk = self.makeToken(SymOpenSquare, ch)
		case ']':
			tk = self.makeToken(SymClosedSquare, ch)
		case '{':
			tk = self.makeToken(SymOpenBrace, ch)
		case '}':
			tk = self.makeToken(SymClosedBrace, ch)
		case '~':
			tk = self.makeToken(SymTilde, ch)
		case 0:
			// Reached on a NUL byte or when the read failed.
			if escape {
				tk = self.makeErrorToken(errors.New("incomplete escape sequence"))
			}
			escape = false
		default:
			if (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') {
				if tk = self.fetchIdentifier(ch); tk.Sym == SymKwFunc {
					// The function keyword directly followed by '(' becomes a
					// single function-definition token.
					if next, _ := self.peek(); next == '(' {
						tk = self.moveOn(SymFuncDef, ch, next)
					}
				}
			} else if ch >= '0' && ch <= '9' {
				tk = self.parseNumber(ch)
			}
		}
		if !escape {
			break
		}
	}
	if tk == nil {
		tk = self.makeErrorToken(fmt.Errorf("unknown symbol '%c'", ch))
	}
	return
}

// sync pushes the last read byte back onto the stream, but only when the read
// that produced it succeeded. A non-nil err is returned unchanged and nothing
// is unread; otherwise the result of unreadChar is returned.
func (self *scanner) sync(err error) error {
	if err != nil {
		return err
	}
	return self.unreadChar()
}

// isBinaryDigit reports whether ch is the ASCII character '0' or '1'.
func isBinaryDigit(ch byte) bool {
	switch ch {
	case '0', '1':
		return true
	default:
		return false
	}
}

// isOctalDigit reports whether ch is an ASCII digit in the range '0'..'7'.
func isOctalDigit(ch byte) bool {
	if ch < '0' {
		return false
	}
	return ch < '8'
}

// isDecimalDigit reports whether ch is an ASCII digit in the range '0'..'9'.
func isDecimalDigit(ch byte) bool {
	if ch > '9' {
		return false
	}
	return '0' <= ch
}

// isHexDigit reports whether ch is an ASCII hexadecimal digit: '0'..'9',
// 'a'..'f' or 'A'..'F'.
func isHexDigit(ch byte) bool {
	switch {
	case '0' <= ch && ch <= '9':
		return true
	case 'a' <= ch && ch <= 'f':
		return true
	case 'A' <= ch && ch <= 'F':
		return true
	}
	return false
}

// initBase inspects the byte that follows a leading '0' to find out which
// numeric base the literal uses.
//
// A following 'b'/'B', 'o'/'O' or 'x'/'X' selects base 2, 8 or 16: the marker
// is consumed and the byte after it is read and returned as firstCh. With any
// other follower the literal stays decimal, nothing is consumed and firstCh is
// currentFirstCh. digitFunc is the digit predicate matching the base.
//
// err is set when firstCh is not a valid digit for the selected base, or when
// reading fails. Reaching the end of the input while peeking is not an error.
// The sb parameter is not used.
func (self *scanner) initBase(sb *strings.Builder, currentFirstCh byte) (firstCh byte, numBase int, digitFunc func(byte) bool, err error) {
	firstCh, numBase, digitFunc = currentFirstCh, 10, isDecimalDigit
	digitType := "decimal"

	var marker byte
	if marker, err = self.peek(); err != nil {
		if err == io.EOF {
			// A lone '0' at the very end of the input is a valid literal.
			err = nil
		}
		return
	}

	switch marker {
	case 'b', 'B':
		numBase, digitType, digitFunc = 2, "binary", isBinaryDigit
	case 'o', 'O':
		numBase, digitType, digitFunc = 8, "octal", isOctalDigit
	case 'x', 'X':
		numBase, digitType, digitFunc = 16, "hex", isHexDigit
	}

	if numBase != 10 {
		self.readChar() // consume the base marker that was peeked
		firstCh, err = self.readChar()
	}

	if err == nil && !digitFunc(firstCh) {
		err = fmt.Errorf("expected %s digit, got '%c'", digitType, firstCh)
	}
	return
}

// parseNumber scans a numeric literal whose first byte, firstCh, has already
// been consumed, and returns a SymInteger or SymFloat value token.
//
// A leading '0' is handed to initBase, which may switch to base 2, 8 or 16.
// Only base-10 literals may carry a fractional part ('.' followed by optional
// digits) and/or an exponent ('e'/'E', optional sign, at least one digit);
// either one turns the literal into a float.
//
// The byte that terminates the literal is pushed back onto the stream. An
// error token is returned for read errors other than io.EOF, for a failed
// push-back, for a malformed exponent and for text rejected by strconv.
func (self *scanner) parseNumber(firstCh byte) (tk *Token) {
	var err error
	var ch byte
	var sym Symbol = SymInteger
	var sb strings.Builder
	var isDigit func(byte) bool = isDecimalDigit
	var numBase = 10

	if firstCh == '0' {
		firstCh, numBase, isDigit, err = self.initBase(&sb, firstCh)
	}

	// Integer part (the whole literal for non-decimal bases).
	for ch = firstCh; err == nil && isDigit(ch); ch, err = self.readChar() {
		sb.WriteByte(ch)
	}

	if numBase == 10 {
		// Optional fractional part; the digits after the dot may be missing.
		if err == nil && ch == '.' {
			sym = SymFloat
			sb.WriteByte(ch)
			for ch, err = self.readChar(); err == nil && isDecimalDigit(ch); ch, err = self.readChar() {
				sb.WriteByte(ch)
			}
		}
		// Optional exponent; at least one digit is mandatory after the sign.
		if err == nil && (ch == 'e' || ch == 'E') {
			sym = SymFloat
			sb.WriteByte(ch)
			if ch, err = self.readChar(); err == nil {
				if ch == '+' || ch == '-' {
					sb.WriteByte(ch)
					ch, err = self.readChar()
				}
				if isDecimalDigit(ch) {
					for ; err == nil && isDecimalDigit(ch); ch, err = self.readChar() {
						sb.WriteByte(ch)
					}
				} else {
					err = errors.New("expected integer exponent")
				}
			}
		}
	}

	if err != nil && err != io.EOF {
		return self.makeErrorToken(err)
	}

	// Push back the terminating byte. When the input ended (io.EOF) there is
	// nothing to push back and sync returns the io.EOF untouched. Any other
	// failure is reported instead of being overwritten by the conversion
	// result below.
	if err = self.sync(err); err != nil && err != io.EOF {
		return self.makeErrorToken(err)
	}

	var value any
	txt := sb.String()
	if sym == SymFloat {
		value, err = strconv.ParseFloat(txt, 64)
	} else {
		value, err = strconv.ParseInt(txt, numBase, 64)
	}
	if err != nil {
		return self.makeErrorToken(err)
	}
	return self.makeValueToken(sym, txt, value)
}

// fetchIdentifier scans a word made of ASCII letters, digits and underscores.
// firstCh is the first byte of the word and has already been consumed.
//
// The byte that ends the word is pushed back onto the stream. The word is then
// classified, in this order:
//   - its upper-case form is a key of keywords: keyword token;
//   - its upper-case form is TRUE or FALSE: SymBool token;
//   - it is directly followed by '(': the '(' is consumed and a SymFuncCall
//     token is returned;
//   - anything else: SymIdentifier token.
//
// An error token is returned when reading or pushing back fails with anything
// other than io.EOF.
func (self *scanner) fetchIdentifier(firstCh byte) (tk *Token) {
	var (
		err  error
		word strings.Builder
	)

	isWordByte := func(b byte) bool {
		switch {
		case b == '_':
			return true
		case 'a' <= b && b <= 'z', 'A' <= b && b <= 'Z':
			return true
		case '0' <= b && b <= '9':
			return true
		}
		return false
	}

	for ch := firstCh; err == nil && isWordByte(ch); ch, err = self.readChar() {
		word.WriteByte(ch)
	}

	if err != nil && err != io.EOF {
		return self.makeErrorToken(err)
	}
	if err = self.sync(err); err != nil && err != io.EOF {
		return self.makeErrorToken(err)
	}

	txt := word.String()
	uptxt := strings.ToUpper(txt)

	if sym, ok := keywords[uptxt]; ok {
		return self.makeKeywordToken(sym, uptxt)
	}

	switch uptxt {
	case `TRUE`:
		return self.makeValueToken(SymBool, txt, true)
	case `FALSE`:
		return self.makeValueToken(SymBool, txt, false)
	}

	if ch, _ := self.peek(); ch == '(' {
		self.readChar() // the '(' becomes part of the function-call token
		return self.makeValueToken(SymFuncCall, txt+"(", txt)
	}
	return self.makeValueToken(SymIdentifier, txt, txt)
}

// fetchBlockComment reads a comment up to and including the closing "*/" and
// returns it as a SymComment token. Running out of input before the closing
// marker is not accepted.
func (self *scanner) fetchBlockComment() *Token {
	const allowEos = false
	return self.fetchUntil(SymComment, allowEos, '*', '/')
}

// fetchOnLineComment reads a comment up to and including the next '\n' and
// returns it as a SymComment token. The comment may also be closed by the end
// of the input.
func (self *scanner) fetchOnLineComment() *Token {
	const allowEos = true
	return self.fetchUntil(SymComment, allowEos, '\n')
}

// fetchUntil consumes bytes until the last len(endings) bytes read match
// endings, and returns a value token of kind sym whose value is the text read
// without the trailing endings. The token source is left empty.
//
// When the input ends before endings is found the outcome depends on
// allowEos: if true, everything read so far becomes the token value; if
// false, an error token is returned. Read errors other than io.EOF always
// produce an error token, so that a genuine I/O failure is never mistaken for
// a regular end of the input.
func (self *scanner) fetchUntil(sym Symbol, allowEos bool, endings ...byte) (tk *Token) {
	var err error
	var ch byte
	var sb strings.Builder

	// tail tracks the most recent len(endings) bytes for the terminator check.
	tail := NewByteSlider(len(endings))
	for {
		if ch, err = self.readChar(); err != nil {
			break
		}
		sb.WriteByte(ch)
		tail.PushEnd(ch)
		if tail.Equal(endings) {
			text := sb.String()
			return self.makeValueToken(sym, "", text[:len(text)-len(endings)])
		}
	}

	if allowEos && err == io.EOF {
		return self.makeValueToken(sym, "", sb.String())
	}
	return self.makeErrorToken(err)
}

// fetchString scans the body of a double-quoted string; the opening quote has
// already been consumed by the caller.
//
// A backslash escapes the byte that follows it: \n, \r and \t become newline,
// carriage return and tab, while any other escaped byte (including '"' and
// '\\') stands for itself. The first unescaped '"' ends the string.
//
// The result is a SymString token whose value is the decoded text and whose
// source is that decoded text wrapped in double quotes. An error token is
// returned when the input ends before the closing quote or when a read fails.
func (self *scanner) fetchString() (tk *Token) {
	var (
		err     error
		ch      byte
		escaped bool // the previous byte was an unescaped backslash
		closed  bool // the terminating quote has been consumed
		sb      strings.Builder
	)

	for !closed {
		if ch, err = self.readChar(); err != nil {
			break
		}
		switch {
		case escaped:
			switch ch {
			case 'n':
				ch = '\n'
			case 'r':
				ch = '\r'
			case 't':
				ch = '\t'
			}
			sb.WriteByte(ch)
			escaped = false
		case ch == '"':
			closed = true
		case ch == '\\':
			escaped = true
		default:
			sb.WriteByte(ch)
		}
	}

	if err == io.EOF {
		return self.makeErrorToken(errors.New("missing string termination \""))
	}
	if err != nil {
		return self.makeErrorToken(err)
	}

	txt := sb.String()
	return self.makeValueToken(SymString, `"`+txt+`"`, txt)
}

// peek returns the next byte of the stream without consuming it and without
// touching the row/column counters. When no byte is available next is 0 and
// err is the error reported by the underlying reader.
func (self *scanner) peek() (next byte, err error) {
	head, err := self.stream.Peek(1)
	if err != nil {
		return 0, err
	}
	return head[0], nil
}

// skipBlanks consumes every leading byte whose value is not greater than the
// ASCII space (32), i.e. blanks and control characters. It stops, returning
// nil, in front of the first other byte; if peeking fails the error is
// returned (the end of the input is reported this way too).
func (self *scanner) skipBlanks() (err error) {
	for {
		var head []byte
		if head, err = self.stream.Peek(1); err != nil || head[0] > ' ' {
			return err
		}
		self.readChar() // go through readChar so that row/column stay updated
	}
}

// translate returns the replacement registered for sym in the translation
// table, or sym itself when the table is nil or holds no entry for it.
func (self *scanner) translate(sym Symbol) Symbol {
	// Looking up a key in a nil map is legal and simply reports "not found".
	if replacement, found := self.translations[sym]; found {
		return replacement
	}
	return sym
}

// moveOn builds the token for a multi-byte symbol and consumes the bytes that
// were only peeked so far. chars holds every byte of the symbol; the first one
// has already been read by the caller, so len(chars)-1 bytes are read here.
//
// The token is created before reading, hence it carries the position the
// scanner had when the call was made.
func (self *scanner) moveOn(sym Symbol, chars ...byte) (tk *Token) {
	tk = NewToken(self.row, self.column, self.translate(sym), string(chars))
	for pending := len(chars) - 1; pending > 0; pending-- {
		self.readChar()
	}
	return tk
}

// makeToken creates a token at the current row/column for the (translated)
// symbol sym, using chars as its source text. Nothing is read from the stream.
func (self *scanner) makeToken(sym Symbol, chars ...byte) (tk *Token) {
	source := string(chars)
	return NewToken(self.row, self.column, self.translate(sym), source)
}

// makeKeywordToken creates a token at the current row/column for the
// (translated) keyword symbol sym, using upperCaseKeyword as its source text.
func (self *scanner) makeKeywordToken(sym Symbol, upperCaseKeyword string) (tk *Token) {
	translated := self.translate(sym)
	return NewToken(self.row, self.column, translated, upperCaseKeyword)
}

// makeValueToken creates a token at the current row/column for the
// (translated) symbol sym, carrying both its source text and its value.
func (self *scanner) makeValueToken(sym Symbol, source string, value any) (tk *Token) {
	translated := self.translate(sym)
	return NewValueToken(self.row, self.column, translated, source, value)
}

// makeErrorToken wraps err into an error token located at the current
// row/column of the scanner.
func (self *scanner) makeErrorToken(err error) *Token {
	row, column := self.row, self.column
	return NewErrorToken(row, column, err)
}