// Copyright (c) 2024 Celestino Amoroso (celestino.amoroso@gmail.com).
// All rights reserved.

// func-string.go
package expr

import (
	"fmt"
	"io"
	"strings"
)

// --- Start of function definitions
// doJoinStr drains it and concatenates the yielded values into one string.
//
// sep is written before every item for which it.Index() is greater than zero.
// NOTE(review): this presumably means "between items", with Index() being the
// 0-based position of the item just returned by Next — confirm.
//
// Every item must be a string. Otherwise the error built by errExpectedGot
// (tagged with funcName) is returned. io.EOF from Next marks normal
// termination. Any other error from Next is returned unchanged, with a nil
// result.
func doJoinStr(funcName string, sep string, it Iterator) (result any, err error) {
	var builder strings.Builder
	for {
		item, nextErr := it.Next()
		if nextErr != nil {
			if nextErr == io.EOF {
				return builder.String(), nil
			}
			return nil, nextErr
		}
		if it.Index() > 0 {
			builder.WriteString(sep)
		}
		str, isString := item.(string)
		if !isString {
			return nil, errExpectedGot(funcName, typeString, item)
		}
		builder.WriteString(str)
	}
}

// joinStrFunc implements joinStr(separator [, parts...]).
//
// args[0] is the separator and must be a string. What gets joined depends on
// how many further arguments are given:
//   - none: the result is "";
//   - exactly one: a *ListType or an Iterator is drained and its items are
//     joined; a plain string is returned as-is (joining a single item needs no
//     separator); any other value is rejected as an invalid "parts" value;
//   - two or more: those arguments themselves are joined.
//
// Every joined item must be a string; doJoinStr reports the first one that
// is not. ctx is unused but required by the function signature.
func joinStrFunc(ctx ExprContext, name string, args []any) (result any, err error) {
	if len(args) < 1 {
		return nil, errMissingRequiredParameter(name, paramSeparator)
	}
	sep, ok := args[0].(string)
	if !ok {
		return nil, errWrongParamType(name, paramSeparator, typeString, args[0])
	}

	switch len(args) {
	case 1:
		return "", nil
	case 2:
		// Case order matters: a *ListType is checked before the generic
		// Iterator interface, matching the original precedence.
		switch parts := args[1].(type) {
		case *ListType:
			return doJoinStr(name, sep, NewListIterator(parts, nil))
		case Iterator:
			return doJoinStr(name, sep, parts)
		case string:
			// Previously rejected as invalid even though joinStr(sep, "a", "b")
			// works; a lone string simply joins to itself.
			return parts, nil
		default:
			return nil, errInvalidParameterValue(name, paramParts, args[1])
		}
	default:
		return doJoinStr(name, sep, NewArrayIterator(args[1:]))
	}
}

// subStrFunc implements subStr(source [, start [, count]]).
//
//   - args[0] (required) is the source string.
//   - args[1] (optional, default 0) is the start offset; a negative value
//     counts back from the end of source.
//   - args[2] (optional) is the number of bytes to take; a negative value, or
//     omitting it, means "up to the end of source".
//
// Offsets and counts are in bytes, not runes, because the result is taken by
// slicing the Go string directly. Values outside source are clamped to its
// bounds (Python-slice style) so the function never panics:
//   - a start past the end yields "";
//   - a negative start reaching before the beginning starts at 0;
//   - a count larger than what remains stops at the end of source.
//
// ctx is unused but required by the function signature.
func subStrFunc(ctx ExprContext, name string, args []any) (result any, err error) {
	var start = 0
	var count = -1
	var source string
	var ok bool

	if len(args) < 1 {
		return nil, errMissingRequiredParameter(name, paramSource)
	}
	if source, ok = args[0].(string); !ok {
		return nil, errWrongParamType(name, paramSource, typeString, args[0])
	}
	if len(args) > 1 {
		if start, err = toInt(args[1], name+"()"); err != nil {
			return
		}
		if len(args) > 2 {
			if count, err = toInt(args[2], name+"()"); err != nil {
				return
			}
		}
	}

	size := len(source)
	if start < 0 {
		start += size
	}
	// Clamp start into [0, size] so the slice expression below cannot panic.
	start = max(0, min(start, size))

	// Clamp count to what remains after start. Comparing against remaining
	// (rather than computing start+count) also avoids int overflow for huge counts.
	remaining := size - start
	if count < 0 || count > remaining {
		count = remaining
	}
	result = source[start : start+count]
	return
}

// trimStrFunc implements trimStr(source). It returns source with leading and
// trailing white space removed, as done by strings.TrimSpace. Only args[0] is
// examined; it must be a string. ctx is unused but required by the function
// signature.
func trimStrFunc(ctx ExprContext, name string, args []any) (result any, err error) {
	if len(args) == 0 {
		return nil, errMissingRequiredParameter(name, paramSource)
	}
	text, isString := args[0].(string)
	if !isString {
		return nil, errWrongParamType(name, paramSource, typeString, args[0])
	}
	return strings.TrimSpace(text), nil
}

// startsWithStrFunc implements startsWithStr(source, target...). It returns
// true when source begins with at least one of the target strings.
// ctx is unused but required by the function signature.
func startsWithStrFunc(ctx ExprContext, name string, args []any) (result any, err error) {
	return matchAnyAffixStr(name, args, strings.HasPrefix)
}

// endsWithStrFunc implements endsWithStr(source, target...). It returns
// true when source ends with at least one of the target strings.
// ctx is unused but required by the function signature.
func endsWithStrFunc(ctx ExprContext, name string, args []any) (result any, err error) {
	return matchAnyAffixStr(name, args, strings.HasSuffix)
}

// matchAnyAffixStr is the shared body of startsWithStr and endsWithStr.
//
// args[0] must be the source string. Each later argument is a target that is
// tested with match(source, target), in order. The scan stops with true at the
// first match. A non-string target reached before any match aborts the scan
// with an error that names its 1-based position among the targets. Targets
// after a match are never inspected. The result is false whenever an error is
// returned.
func matchAnyAffixStr(name string, args []any, match func(s, affix string) bool) (any, error) {
	if len(args) < 1 {
		return false, errMissingRequiredParameter(name, paramSource)
	}
	source, ok := args[0].(string)
	if !ok {
		return false, errWrongParamType(name, paramSource, typeString, args[0])
	}
	for i, targetSpec := range args[1:] {
		target, ok := targetSpec.(string)
		if !ok {
			return false, fmt.Errorf("target item nr %d is %T, expected string", i+1, targetSpec)
		}
		if match(source, target) {
			return true, nil
		}
	}
	return false, nil
}

// splitStrFunc implements splitStr(source [, separator [, count]]).
//
// It splits source around each occurrence of separator and returns the parts
// as a *ListType holding strings. The optional count follows strings.SplitN:
//   - count > 0: at most count parts, the last holding the unsplit remainder;
//   - count == 0: an empty list;
//   - count < 0 (the default): all parts.
//
// When separator is omitted it is "", which makes strings.SplitN split source
// after each UTF-8 sequence. ctx is unused but required by the function
// signature.
func splitStrFunc(ctx ExprContext, name string, args []any) (result any, err error) {
	var source, sep string
	var ok bool
	count := -1

	if len(args) < 1 {
		return nil, errMissingRequiredParameter(name, paramSource)
	}
	if source, ok = args[0].(string); !ok {
		return nil, errWrongParamType(name, paramSource, typeString, args[0])
	}
	if len(args) >= 2 {
		if sep, ok = args[1].(string); !ok {
			return nil, fmt.Errorf("separator param must be string, got %T (%v)", args[1], args[1])
		}
		if len(args) >= 3 {
			// Use toInt, as subStr does, instead of a raw int64 assertion. The
			// historical message is kept for callers that match on it.
			if count, err = toInt(args[2], name+"()"); err != nil {
				return nil, fmt.Errorf("part count must be integer, got %T (%v)", args[2], args[2])
			}
		}
	}

	// strings.SplitN already covers n > 0, n == 0 (nil) and n < 0 (all parts).
	parts := strings.SplitN(source, sep, count)
	list := make(ListType, len(parts))
	for i, part := range parts {
		list[i] = part
	}
	return &list, nil
}

// --- End of function definitions

// ImportStringFuncs registers the string utility functions of this file
// (joinStr, subStr, splitStr, trimStr, startsWithStr, endsWithStr) in ctx.
//
// Each entry carries the minimum and maximum argument counts handed to
// ctx.RegisterFunc. NOTE(review): a maximum of -1 presumably means "no upper
// limit", since joinStr is variadic — confirm against RegisterFunc.
func ImportStringFuncs(ctx ExprContext) {
	builtins := []struct {
		name    string
		functor *simpleFunctor
		minArgs int
		maxArgs int
	}{
		{"joinStr", &simpleFunctor{f: joinStrFunc}, 1, -1},
		{"subStr", &simpleFunctor{f: subStrFunc}, 1, -1},
		{"splitStr", &simpleFunctor{f: splitStrFunc}, 2, -1},
		{"trimStr", &simpleFunctor{f: trimStrFunc}, 1, -1},
		{"startsWithStr", &simpleFunctor{f: startsWithStrFunc}, 2, -1},
		{"endsWithStr", &simpleFunctor{f: endsWithStrFunc}, 2, -1},
	}
	for _, b := range builtins {
		ctx.RegisterFunc(b.name, b.functor, b.minArgs, b.maxArgs)
	}
}

// Register the import function in the import-register.
// This allows all functions of this module to be imported by the "builtin" operator.
func init() {
	// Name and description under which this module is exposed to registerImport.
	const (
		importName = "string"
		importDesc = "string utilities"
	)
	registerImport(importName, ImportStringFuncs, importDesc)
}