// expr project scanner.go
package expr

import (
	"bufio"
	"errors"
	"io"
	"strconv"
	"strings"
)

// scanner turns a byte stream into Tokens, keeping one token of look-ahead
// and a running row/column position used when building error tokens.
type scanner struct {
	// current is the token that the next call to Next will return; it is
	// fetched ahead of time by NewScanner and by Next itself.
	current      *Token
	// prev is the token most recently returned by Next (nil before the first call).
	prev         *Token
	// stream is the buffered reader wrapped around the caller's io.Reader.
	stream       *bufio.Reader
	// row and column are position counters maintained by readChar and
	// unreadChar; NewScanner initialises both to 1.
	row          int
	column       int
	// translations optionally maps a scanned symbol to the symbol actually
	// stored in the token (see translate); it may be nil.
	translations map[Symbol]Symbol
}

// NewScanner builds a scanner that reads from s through a buffered reader,
// starts position tracking at row 1, column 1, and applies the given symbol
// translations (which may be nil). The first token is fetched immediately so
// that it is available to the first call to Next.
func NewScanner(s io.Reader, translations map[Symbol]Symbol) (inst *scanner) {
	inst = new(scanner)
	inst.stream = bufio.NewReader(s)
	inst.row, inst.column = 1, 1
	inst.translations = translations

	// Prime the one-token look-ahead.
	inst.current = inst.fetchNextToken()
	return
}

// DefaultTranslations returns a newly allocated translation table that maps
// several alternative operator symbols onto SymAnd, SymOr, SymNot and
// SymNotEqual. Each call returns an independent map.
func DefaultTranslations() map[Symbol]Symbol {
	table := make(map[Symbol]Symbol, 6)

	// Symbols folded into SymAnd.
	table[SymDoubleAmpersand] = SymAnd
	table[SymKwAnd] = SymAnd

	// Symbols folded into SymOr.
	table[SymDoubleVertBar] = SymOr
	table[SymKwOr] = SymOr

	// Remaining one-to-one aliases.
	table[SymKwNot] = SymNot
	table[SymLessGreater] = SymNotEqual

	return table
}

// func (self *scanner) Current() *Token {
// 	return self.current
// }

// readChar consumes one byte from the stream and updates the position
// counters: a newline advances the row and resets the column to 0, any other
// byte advances the column. When the read fails the counters are left
// untouched and the error is returned as-is.
func (self *scanner) readChar() (ch byte, err error) {
	ch, err = self.stream.ReadByte()
	if err != nil {
		return
	}

	switch ch {
	case '\n':
		self.row++
		self.column = 0
	default:
		self.column++
	}
	return
}

// unreadChar pushes the most recently read byte back onto the stream and
// reverses the position update that readChar applied for it.
//
// Pushing back a newline restores the row that readChar advanced; pushing
// back any other byte moves the column back by one. Without the newline case
// the row counter would drift by one each time a token's look-ahead byte is
// a newline, because that newline is read, pushed back, and then read again.
//
// The error from bufio.Reader.UnreadByte is returned unchanged; it is the
// only failure mode (for example when no byte has been read since the last
// Peek or UnreadByte).
func (self *scanner) unreadChar() (err error) {
	if err = self.stream.UnreadByte(); err != nil {
		return
	}

	// After a successful UnreadByte the pushed-back byte is the next one in
	// the stream, so Peek tells us which readChar branch has to be undone.
	if back, peekErr := self.stream.Peek(1); peekErr == nil && back[0] == '\n' {
		// readChar did row++ and column = 0. The column of the previous
		// line is not tracked, so it is left at 0; re-reading the newline
		// brings the scanner back to exactly the same state.
		self.row--
		return
	}

	self.column--
	return
}

// Previous reports the token handed out by the latest call to Next, or nil
// if Next has not been called yet.
func (self *scanner) Previous() *Token {
	last := self.prev
	return last
}

// Next returns the look-ahead token, records it as the previous token and
// fetches a new look-ahead token from the stream.
func (self *scanner) Next() (tk *Token) {
	tk = self.current
	self.prev = tk
	self.current = self.fetchNextToken()
	return
}

// fetchNextToken skips leading blanks and scans the next token from the
// stream.
//
// If skipBlanks fails (including at end of input) the error is returned
// wrapped in an error token. Otherwise one byte is read and dispatched:
// punctuation yields operator tokens (with one byte of look-ahead for the
// two-byte operators), "/*" and "//" start comments, '"' starts a string,
// letters start identifiers and digits (or '.' followed by a digit) start
// numbers. A byte that starts no known token yields an error token rather
// than a nil *Token.
//
// A backslash sets the escape flag and the loop reads one more byte: a second
// backslash yields SymBackSlash, a double quote yields SymDoubleQuote, and a
// failed read yields an "incomplete escape sequence" error token.
//
// NOTE(review): only the backslash, double-quote and zero-byte branches clear
// the escape flag. After a backslash followed by any other byte the loop keeps
// running and overwrites tk until one of those three branches is hit. This
// looks unintended, but the desired behaviour is not clear from this file, so
// it is left unchanged.
func (self *scanner) fetchNextToken() (tk *Token) {
	if err := self.skipBlanks(); err != nil {
		return self.makeErrorToken(err)
	}

	escape := false
	for {
		// The read error is deliberately dropped: a failed read returns
		// ch == 0, which is handled by the `case 0` branch below.
		ch, _ := self.readChar()
		switch ch {
		case '+':
			if next, _ := self.peek(); next == '+' {
				tk = self.moveOn(SymDoublePlus, ch, next)
			} else if next == '=' {
				tk = self.moveOn(SymPlusEqual, ch, next)
			} else {
				tk = self.makeToken(SymPlus, ch)
			}
		case '-':
			if next, _ := self.peek(); next == '-' {
				tk = self.moveOn(SymDoubleMinus, ch, next)
			} else if next == '=' {
				tk = self.moveOn(SymMinusEqual, ch, next)
			} else {
				tk = self.makeToken(SymMinus, ch)
			}
		case '*':
			if next, _ := self.peek(); next == '*' {
				tk = self.moveOn(SymDoubleStar, ch, next)
			} else {
				tk = self.makeToken(SymStar, ch)
			}
		case '/':
			if next, _ := self.peek(); next == '*' {
				// Consume the '*' that was only peeked, then read up to "*/".
				self.readChar()
				tk = self.fetchBlockComment()
			} else if next == '/' {
				// Consume the second '/', then read up to the end of line.
				self.readChar()
				tk = self.fetchOnLineComment()
			} else {
				tk = self.makeToken(SymSlash, ch)
			}
		case '\\':
			if escape {
				tk = self.makeToken(SymBackSlash, ch)
				escape = false
			} else {
				escape = true
			}
		case '|':
			if next, _ := self.peek(); next == '|' {
				tk = self.moveOn(SymDoubleVertBar, ch, next)
			} else {
				tk = self.makeToken(SymVertBar, ch)
			}
		case ',':
			tk = self.makeToken(SymComma, ch)
		case ':':
			tk = self.makeToken(SymColon, ch)
		case ';':
			tk = self.makeToken(SymSemiColon, ch)
		case '.':
			// ".5" is a number; a lone '.' is the dot operator.
			if next, _ := self.peek(); next >= '0' && next <= '9' {
				tk = self.parseNumber(ch)
			} else {
				tk = self.makeToken(SymDot, ch)
			}
		case '\'':
			tk = self.makeToken(SymQuote, ch)
		case '"':
			if escape {
				tk = self.makeToken(SymDoubleQuote, ch)
				escape = false
			} else {
				tk = self.fetchString()
			}
		case '`':
			tk = self.makeToken(SymBackTick, ch)
		case '!':
			if next, _ := self.peek(); next == '=' {
				tk = self.moveOn(SymNotEqual, ch, next)
			} else {
				tk = self.makeToken(SymExclamation, ch)
			}
		case '?':
			tk = self.makeToken(SymQuestion, ch)
		case '&':
			if next, _ := self.peek(); next == '&' {
				tk = self.moveOn(SymDoubleAmpersand, ch, next)
			} else {
				tk = self.makeToken(SymAmpersand, ch)
			}
		case '%':
			tk = self.makeToken(SymPercent, ch)
		case '#':
			tk = self.makeToken(SymHash, ch)
		case '@':
			tk = self.makeToken(SymAt, ch)
		case '_':
			// A leading underscore is always its own token, so identifiers
			// can contain underscores but never start with one.
			tk = self.makeToken(SymUndescore, ch)
		case '=':
			if next, _ := self.peek(); next == '=' {
				tk = self.moveOn(SymDoubleEqual, ch, next)
			} else {
				tk = self.makeToken(SymEqual, ch)
			}
		case '<':
			if next, _ := self.peek(); next == '=' {
				tk = self.moveOn(SymLessOrEqual, ch, next)
			} else if next == '>' {
				tk = self.moveOn(SymLessGreater, ch, next)
			} else {
				tk = self.makeToken(SymLess, ch)
			}
		case '>':
			if next, _ := self.peek(); next == '=' {
				tk = self.moveOn(SymGreaterOrEqual, ch, next)
			} else {
				tk = self.makeToken(SymGreater, ch)
			}
		case '$':
			tk = self.makeToken(SymDollar, ch)
		case '(':
			tk = self.makeToken(SymOpenRound, ch)
		case ')':
			tk = self.makeToken(SymClosedRound, ch)
		case '[':
			tk = self.makeToken(SymOpenSquare, ch)
		case ']':
			tk = self.makeToken(SymClosedSquare, ch)
		case '{':
			tk = self.makeToken(SymOpenBrace, ch)
		case '}':
			tk = self.makeToken(SymClosedBrace, ch)
		case 0:
			if escape {
				tk = self.makeErrorToken(errors.New("incomplete escape sequence"))
			}
			escape = false
		default:
			// '_' never reaches this branch because `case '_'` above
			// intercepts it, so only letters start an identifier.
			if (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') {
				tk = self.fetchIdentifier(ch)
			} else if ch >= '0' && ch <= '9' {
				tk = self.parseNumber(ch)
			} else {
				// Report bytes that start no token instead of handing a
				// nil *Token to the caller. Quote renders non-printable
				// and non-UTF-8 bytes as escapes.
				tk = self.makeErrorToken(errors.New("unexpected character " + strconv.Quote(string([]byte{ch}))))
			}
		}
		if !escape {
			break
		}
	}
	return
}

// sync pushes the look-ahead byte back onto the stream, but only when the
// read that produced it succeeded. A non-nil err is passed through unchanged
// and nothing is pushed back; otherwise the result of unreadChar is returned.
func (self *scanner) sync(err error) error {
	if err != nil {
		return err
	}
	return self.unreadChar()
}

// parseNumber scans a numeric literal whose first byte, firstCh, has already
// been consumed, and returns a SymInteger token (int64 value) or a SymFloat
// token (float64 value).
//
// Accepted forms:
//   - decimal integers: one or more digits;
//   - floats: digits with a '.' and optional fractional digits (firstCh may
//     itself be '.'), and/or an exponent 'e'/'E' with an optional sign and
//     at least one digit;
//   - radix integers: "0x", "0o" or "0b" followed by at least one digit that
//     is valid in base 16, 8 or 2. When no valid digit follows the prefix
//     letter, the leading "0" is scanned as a plain decimal integer and the
//     letter is left in the stream.
//
// The byte that terminates the literal is pushed back with sync. An error
// token is returned for read errors other than io.EOF, for a failed push-back,
// for an exponent without digits, and for text strconv cannot convert (for
// example an out-of-range value).
func (self *scanner) parseNumber(firstCh byte) (tk *Token) {
	var err error
	var ch byte
	var sym Symbol = SymInteger
	var value any
	var sb strings.Builder

	// digitOf reports whether c is a valid digit in the given radix
	// (2, 8, 10 or 16).
	digitOf := func(c byte, radix int) bool {
		switch {
		case c >= '0' && c <= '9':
			return int(c-'0') < radix
		case (c >= 'a' && c <= 'f') || (c >= 'A' && c <= 'F'):
			return radix == 16
		}
		return false
	}

	// Detect a radix prefix before consuming anything. Two bytes are peeked
	// rather than read because only one byte can be pushed back afterwards.
	// The Peek error is ignored on purpose: fewer than two bytes available
	// simply means there is no prefix.
	base := 10
	if firstCh == '0' {
		if ahead, _ := self.stream.Peek(2); len(ahead) == 2 {
			radix := 0
			switch ahead[0] {
			case 'x':
				radix = 16
			case 'o':
				radix = 8
			case 'b':
				radix = 2
			}
			if radix != 0 && digitOf(ahead[1], radix) {
				base = radix
			}
		}
	}

	if base != 10 {
		sb.WriteByte(firstCh)
		if ch, err = self.readChar(); err == nil {
			sb.WriteByte(ch) // the radix letter seen by Peek above
			for ch, err = self.readChar(); err == nil && digitOf(ch, base); ch, err = self.readChar() {
				sb.WriteByte(ch)
			}
		}
	} else {
		// Integer part (empty when firstCh is '.').
		for ch = firstCh; err == nil && digitOf(ch, 10); ch, err = self.readChar() {
			sb.WriteByte(ch)
		}
		// Optional fractional part.
		if ch == '.' {
			sym = SymFloat
			sb.WriteByte(ch)
			for ch, err = self.readChar(); err == nil && digitOf(ch, 10); ch, err = self.readChar() {
				sb.WriteByte(ch)
			}
		}
		// Optional exponent: at least one digit is required after the
		// optional sign.
		if ch == 'e' || ch == 'E' {
			sym = SymFloat
			sb.WriteByte(ch)
			if ch, err = self.readChar(); err == nil {
				if ch == '+' || ch == '-' {
					sb.WriteByte(ch)
					ch, err = self.readChar()
				}
				if digitOf(ch, 10) {
					for ; err == nil && digitOf(ch, 10); ch, err = self.readChar() {
						sb.WriteByte(ch)
					}
				} else {
					err = errors.New("expected integer exponent")
				}
			}
		}
	}

	if err != nil && err != io.EOF {
		return self.makeErrorToken(err)
	}
	// Give the terminating byte back to the stream; at io.EOF there is
	// nothing to push back and sync passes the error through.
	if err = self.sync(err); err != nil && err != io.EOF {
		return self.makeErrorToken(err)
	}

	txt := sb.String()
	switch {
	case sym == SymFloat:
		value, err = strconv.ParseFloat(txt, 64)
	case base != 10:
		// With an explicit base ParseInt rejects the "0x"/"0o"/"0b"
		// prefix, so only the digits are handed over.
		value, err = strconv.ParseInt(txt[2:], base, 64)
	default:
		value, err = strconv.ParseInt(txt, 10, 64)
	}
	if err != nil {
		return self.makeErrorToken(err)
	}
	return self.makeValueToken(sym, txt, value)
}

// fetchIdentifier scans a word made of letters, digits and underscores whose
// first byte, firstCh, has already been consumed, and classifies it.
//
// The terminating byte is pushed back with sync. The word is then matched
// case-insensitively, in this order:
//  1. an entry of the keywords table: keyword token carrying the upper-case text;
//  2. TRUE / FALSE: SymBool token with the corresponding bool value;
//  3. a word immediately followed by '(': SymFunction token; the '(' is
//     consumed, the source text is the word plus "(" and the value is the
//     upper-case name;
//  4. anything else: SymIdentifier token whose value is the word as written.
//
// Read or push-back errors other than io.EOF produce an error token.
func (self *scanner) fetchIdentifier(firstCh byte) (tk *Token) {
	var err error
	var sb strings.Builder

	isWordByte := func(c byte) bool {
		return c == '_' ||
			(c >= 'a' && c <= 'z') ||
			(c >= 'A' && c <= 'Z') ||
			(c >= '0' && c <= '9')
	}

	ch := firstCh
	for err == nil && isWordByte(ch) {
		sb.WriteByte(ch)
		ch, err = self.readChar()
	}

	// Push back the terminating byte unless the read loop already failed
	// with a real error (sync itself passes io.EOF through untouched).
	if err == nil || err == io.EOF {
		err = self.sync(err)
	}
	if err != nil && err != io.EOF {
		return self.makeErrorToken(err)
	}

	word := sb.String()
	upper := strings.ToUpper(word)

	if sym, found := keywords[upper]; found {
		return self.makeKeywordToken(sym, upper)
	}

	switch upper {
	case `TRUE`:
		return self.makeValueToken(SymBool, word, true)
	case `FALSE`:
		return self.makeValueToken(SymBool, word, false)
	}

	if next, _ := self.peek(); next == '(' {
		// Swallow the opening parenthesis: it is part of the function token.
		self.readChar()
		return self.makeValueToken(SymFunction, word+"(", upper)
	}
	return self.makeValueToken(SymIdentifier, word, word)
}

// fetchBlockComment reads up to and including the closing "*/" and returns
// the enclosed text as a SymComment token. Running out of input before the
// terminator produces an error token.
func (self *scanner) fetchBlockComment() *Token {
	const acceptEndOfStream = false
	return self.fetchUntil(SymComment, acceptEndOfStream, '*', '/')
}

// fetchOnLineComment reads up to and including the next newline and returns
// the text before it as a SymComment token. A comment cut short by the end
// of the input is accepted as well.
func (self *scanner) fetchOnLineComment() *Token {
	const acceptEndOfStream = true
	return self.fetchUntil(SymComment, acceptEndOfStream, '\n')
}

// func (self *scanner) fetchUntil(sym Symbol, allowEos bool, endings ...byte) (tk *Token) {
// 	var err error
// 	var ch byte
// 	var sb strings.Builder
// 	var value string
// 	end := string(endings)
// 	endReached := false
// 	for ch, err = self.readChar(); err == nil && !endReached; {
// 		sb.WriteByte(ch)
// 		if sb.Len() >= len(end) && strings.HasSuffix(sb.String(), end) {
// 			value = sb.String()[0 : sb.Len()-len(end)]
// 			endReached = true
// 		} else {
// 			ch, err = self.readChar()
// 		}
// 	}
// 	if !endReached && allowEos {
// 		value = sb.String()
// 		endReached = true
// 	}

//		if endReached {
//			tk = self.makeValueToken(sym, "", value)
//		} else {
//			tk = self.makeErrorToken(err)
//		}
//		return
//	}

// func (self *scanner) fetchUntil(sym Symbol, allowEos bool, endings ...byte) (tk *Token) {
// 	var err error
// 	var ch byte
// 	var sb strings.Builder
// 	var value string
// 	end := make([]byte, len(endings))
// 	length := 0
// 	endReached := false
// 	for ch, err = self.readChar(); err == nil && !endReached; {
// 		sb.WriteByte(ch)
// 		if length == len(endings) {
// 			for i := 0; i < length-1; i++ {
// 				end[i] = end[i+1]
// 			}
// 			length--
// 		}
// 		end[length] = ch
// 		length++
// 		if bytes.Equal(endings, end) {
// 			value = sb.String()[0 : sb.Len()-len(end)]
// 			endReached = true
// 		} else {
// 			ch, err = self.readChar()
// 		}
// 	}
// 	if !endReached && allowEos {
// 		value = sb.String()
// 		endReached = true
// 	}

// 	if endReached {
// 		tk = self.makeValueToken(sym, "", value)
// 	} else {
// 		tk = self.makeErrorToken(err)
// 	}
// 	return
// }

// fetchUntil consumes bytes until the last len(endings) bytes read equal
// endings, and returns a value token of kind sym whose value is the text read
// before that terminator (the terminator itself is consumed but dropped).
//
// If reading stops on an error first, the outcome depends on allowEos: when
// true, everything read so far becomes the token value; when false, the read
// error is returned as an error token.
func (self *scanner) fetchUntil(sym Symbol, allowEos bool, endings ...byte) (tk *Token) {
	var sb strings.Builder

	// window holds the most recently read bytes for comparison with endings.
	window := NewByteSlider(len(endings))

	ch, err := self.readChar()
	for err == nil {
		sb.WriteByte(ch)
		window.PushEnd(ch)
		if window.Equal(endings) {
			collected := sb.String()
			body := collected[0 : len(collected)-len(endings)]
			return self.makeValueToken(sym, "", body)
		}
		ch, err = self.readChar()
	}

	if allowEos {
		return self.makeValueToken(sym, "", sb.String())
	}
	return self.makeErrorToken(err)
}

// fetchString scans a double-quoted string literal whose opening quote has
// already been consumed and returns a SymString token. The token value is the
// decoded text and the token source is that decoded text wrapped in quotes.
//
// A backslash escapes the next byte: \n, \r and \t decode to newline,
// carriage return and tab; any other escaped byte (including '"' and '\\')
// decodes to itself. Hitting io.EOF before the closing quote yields a
// "missing string termination" error token; other read errors are returned
// as error tokens unchanged.
func (self *scanner) fetchString() (tk *Token) {
	var sb strings.Builder
	escaped := false

	for {
		ch, err := self.readChar()
		if err == io.EOF {
			return self.makeErrorToken(errors.New("missing string termination \""))
		}
		if err != nil {
			return self.makeErrorToken(err)
		}

		switch {
		case escaped:
			escaped = false
			switch ch {
			case 'n':
				sb.WriteByte('\n')
			case 'r':
				sb.WriteByte('\r')
			case 't':
				sb.WriteByte('\t')
			default:
				// Covers \" and \\ as well as unknown escapes.
				sb.WriteByte(ch)
			}
		case ch == '"':
			txt := sb.String()
			return self.makeValueToken(SymString, `"`+txt+`"`, txt)
		case ch == '\\':
			escaped = true
		default:
			sb.WriteByte(ch)
		}
	}
}

// peek returns the next byte of the stream without consuming it. On failure
// it returns a zero byte together with the error from the buffered reader.
func (self *scanner) peek() (next byte, err error) {
	head, err := self.stream.Peek(1)
	if err != nil {
		return 0, err
	}
	return head[0], nil
}

// skipBlanks consumes every leading byte whose value is 32 (space) or lower.
// It returns nil once a byte above 32 is waiting in the stream, or the error
// reported by the look-ahead (for example at end of input).
func (self *scanner) skipBlanks() (err error) {
	for {
		var head []byte
		if head, err = self.stream.Peek(1); err != nil || head[0] > 32 {
			return
		}
		// Read through readChar so that row/column stay up to date.
		self.readChar()
	}
}

// translate maps sym through the scanner's translation table and returns sym
// itself when the table is nil or has no entry for it.
func (self *scanner) translate(sym Symbol) Symbol {
	// Indexing a nil map is legal in Go and reports "not found".
	if mapped, found := self.translations[sym]; found {
		return mapped
	}
	return sym
}

// moveOn builds a token for a multi-byte operator and consumes the bytes of
// it that were only peeked so far: the first byte in chars is assumed to be
// already read, so len(chars)-1 further bytes are read from the stream.
func (self *scanner) moveOn(sym Symbol, chars ...byte) (tk *Token) {
	tk = self.makeToken(sym, chars...)
	for pending := len(chars) - 1; pending > 0; pending-- {
		self.readChar()
	}
	return
}

// makeToken creates a token for the translated form of sym, using chars as
// the token's source text.
func (self *scanner) makeToken(sym Symbol, chars ...byte) (tk *Token) {
	source := string(chars)
	return NewToken(self.translate(sym), source)
}

// makeKeywordToken creates a token for the translated form of sym, using the
// upper-case keyword text as the token's source.
func (self *scanner) makeKeywordToken(sym Symbol, upperCaseKeyword string) (tk *Token) {
	translated := self.translate(sym)
	return NewToken(translated, upperCaseKeyword)
}

// makeValueToken creates a token for the translated form of sym that carries
// both its source text and an associated value.
func (self *scanner) makeValueToken(sym Symbol, source string, value any) (tk *Token) {
	translated := self.translate(sym)
	return NewValueToken(translated, source, value)
}

// makeErrorToken wraps err in an error token stamped with the scanner's
// current row and column.
func (self *scanner) makeErrorToken(err error) *Token {
	row, column := self.row, self.column
	return NewErrorToken(row, column, err)
}