package testing

import (
	"fmt"
	"sync"

	"github.com/influxdata/flux"
	"github.com/influxdata/flux/codes"
	"github.com/influxdata/flux/execute"
	"github.com/influxdata/flux/internal/errors"
	"github.com/influxdata/flux/memory"
	"github.com/influxdata/flux/plan"
	"github.com/influxdata/flux/runtime"
)

// AssertEqualsKind identifies testing.assertEquals. The same string is
// registered as the operation kind, the procedure kind and the
// transformation kind in init.
const AssertEqualsKind = "assertEquals"

// AssertEqualsOpSpec is the operation spec built from a testing.assertEquals
// call. Name is the test name; it is carried through to the transformation,
// which prefixes its failure messages with it.
type AssertEqualsOpSpec struct {
	Name string `json:"name"`
}

// Kind reports AssertEqualsKind as the operation kind.
func (*AssertEqualsOpSpec) Kind() flux.OperationKind {
	return flux.OperationKind(AssertEqualsKind)
}

// init exposes testing.assertEquals as a Flux builtin and registers its
// operation spec, procedure spec and transformation under AssertEqualsKind.
func init() {
	signature := runtime.MustLookupBuiltinType("testing", "assertEquals")
	builtin := flux.MustValue(flux.FunctionValue(AssertEqualsKind, createAssertEqualsOpSpec, signature))
	runtime.RegisterPackageValue("testing", "assertEquals", builtin)

	// Wire the kind through each stage: operation -> procedure -> transformation.
	flux.RegisterOpSpec(AssertEqualsKind, newAssertEqualsOp)
	plan.RegisterProcedureSpec(AssertEqualsKind, newAssertEqualsProcedure, AssertEqualsKind)
	execute.RegisterTransformation(AssertEqualsKind, createAssertEqualsTransformation)
}

// createAssertEqualsOpSpec builds an AssertEqualsOpSpec from the call
// arguments. The "got" and "want" table streams are added as parents in that
// order (createAssertEqualsTransformation treats parent 0 as got and parent 1
// as want). The "name" argument is required.
// Missing or non-table stream arguments yield codes.Invalid errors.
func createAssertEqualsOpSpec(args flux.Arguments, a *flux.Administration) (flux.OperationSpec, error) {
	for _, param := range [...]string{"got", "want"} {
		v, present := args.Get(param)
		if !present {
			return nil, errors.Newf(codes.Invalid, "argument '%s' not present", param)
		}
		parent, isTable := v.(*flux.TableObject)
		if !isTable {
			return nil, errors.Newf(codes.Invalid, "%s input to assertEquals is not a table object", param)
		}
		a.AddParent(parent)
	}

	testName, err := args.GetRequiredString("name")
	if err != nil {
		return nil, err
	}
	return &AssertEqualsOpSpec{Name: testName}, nil
}

// newAssertEqualsOp returns an empty AssertEqualsOpSpec. It is registered
// with flux.RegisterOpSpec as the constructor for AssertEqualsKind.
func newAssertEqualsOp() flux.OperationSpec {
	return &AssertEqualsOpSpec{}
}

// AssertEqualsProcedureSpec is the planner-level form of an assertEquals call.
// Name is the test name copied from the operation spec.
type AssertEqualsProcedureSpec struct {
	plan.DefaultCost
	Name string
}

// Kind reports AssertEqualsKind as the procedure kind.
func (*AssertEqualsProcedureSpec) Kind() plan.ProcedureKind {
	return plan.ProcedureKind(AssertEqualsKind)
}

// Copy returns a new spec holding a copy of every field of s.
func (s *AssertEqualsProcedureSpec) Copy() plan.ProcedureSpec {
	dup := new(AssertEqualsProcedureSpec)
	*dup = *s
	return dup
}

// newAssertEqualsProcedure converts an AssertEqualsOpSpec into an
// AssertEqualsProcedureSpec carrying the same test name. Any other spec type
// is reported as a codes.Internal error.
func newAssertEqualsProcedure(qs flux.OperationSpec, pa plan.Administration) (plan.ProcedureSpec, error) {
	switch op := qs.(type) {
	case *AssertEqualsOpSpec:
		return &AssertEqualsProcedureSpec{Name: op.Name}, nil
	default:
		return nil, errors.Newf(codes.Internal, "invalid spec type %T", qs)
	}
}

// AssertEqualsTransformation compares the tables arriving from two parents,
// "got" and "want", pairing them by group key. The first table seen for a key
// is buffered in the builder cache. When a second table with the same key
// arrives, it is compared against the buffered one. Mismatches are reported as
// *AssertEqualsError values.
type AssertEqualsTransformation struct {
	execute.ExecutionNode
	// mu is locked by Process, UpdateWatermark, UpdateProcessingTime and Finish.
	mu sync.Mutex

	// Per-parent bookkeeping (id, watermark, processing time, table count, finished flag).
	gotParent  *assertEqualsParentState
	wantParent *assertEqualsParentState
	// keysMatched is incremented when a group key is first seen and
	// decremented when a second table with that key arrives; a positive value
	// once both parents finish means some key was seen on only one side.
	keysMatched int
	// unequal is set once a pair of tables with the same key compared unequal.
	unequal bool
	// err is the error the output dataset is finished with.
	err error

	d     execute.Dataset           // output dataset
	cache execute.TableBuilderCache // buffers the first table seen for each group key
	a     memory.Allocator          // passed to execute.TablesEqual

	// name is the test name used as a prefix in failure messages.
	name string
}

// AssertEqualsError is the error produced when assertEquals finds that its
// two inputs differ.
type AssertEqualsError struct {
	msg string
}

// Error returns the failure message verbatim.
func (e *AssertEqualsError) Error() string { return e.msg }

// Assertion always reports true. NOTE(review): presumably used by callers to
// tell assertion failures apart from other execution errors — confirm.
func (*AssertEqualsError) Assertion() bool { return true }

// assertEqualsParentState tracks what AssertEqualsTransformation has observed
// from one of its two parents.
type assertEqualsParentState struct {
	id         execute.DatasetID // the parent dataset this state belongs to
	mark       execute.Time      // latest watermark received from this parent
	processing execute.Time      // latest processing time received from this parent
	ntables    int               // number of tables received from this parent
	finished   bool              // set once Finish has been called for this parent
}

// createAssertEqualsTransformation is the execute.RegisterTransformation
// factory for AssertEqualsKind. It requires exactly two parents: the first is
// treated as "got" and the second as "want". It returns the transformation
// together with its output dataset, which is backed by a table builder cache.
func createAssertEqualsTransformation(id execute.DatasetID, mode execute.AccumulationMode, spec plan.ProcedureSpec, a execute.Administration) (execute.Transformation, execute.Dataset, error) {
	parents := a.Parents()
	if len(parents) != 2 {
		return nil, nil, errors.New(codes.Internal, "assertEquals should have exactly 2 parents")
	}
	procSpec, isAssert := spec.(*AssertEqualsProcedureSpec)
	if !isAssert {
		return nil, nil, errors.Newf(codes.Internal, "invalid spec type %T", spec)
	}

	builders := execute.NewTableBuilderCache(a.Allocator())
	out := execute.NewDataset(id, mode, builders)
	gotID, wantID := parents[0], parents[1]
	return NewAssertEqualsTransformation(out, builders, procSpec, gotID, wantID, a.Allocator()), out, nil
}

// NewAssertEqualsTransformation builds a transformation that compares tables
// from the gotID and wantID parents, writes into d via cache, and uses a when
// comparing tables. spec.Name is used as the test name in failure messages.
func NewAssertEqualsTransformation(d execute.Dataset, cache execute.TableBuilderCache, spec *AssertEqualsProcedureSpec, gotID, wantID execute.DatasetID, a memory.Allocator) *AssertEqualsTransformation {
	got := &assertEqualsParentState{id: gotID}
	want := &assertEqualsParentState{id: wantID}

	// keysMatched, unequal and err start at their zero values.
	tr := &AssertEqualsTransformation{
		gotParent:  got,
		wantParent: want,
		name:       spec.Name,
	}
	tr.d = d
	tr.cache = cache
	tr.a = a
	return tr
}

// RetractTable is not supported by assertEquals. Instead of panicking (which
// would crash the process from library code), it reports a codes.Unimplemented
// error so the failure flows through normal error handling.
func (t *AssertEqualsTransformation) RetractTable(id execute.DatasetID, key flux.GroupKey) error {
	return errors.New(codes.Unimplemented, "assertEquals does not support retracting tables")
}

// Process receives a table from one of the two parents.
//
// The first table seen for a group key is copied into the builder cache. When
// a second table with the same key arrives, it is compared with the buffered
// one using execute.TablesEqual. If they differ, the transformation is marked
// unequal and an *AssertEqualsError is returned. Tables from an unknown
// dataset id are rejected with a codes.Internal error.
func (t *AssertEqualsTransformation) Process(id execute.DatasetID, tbl flux.Table) error {
	t.mu.Lock()
	defer t.mu.Unlock()

	// Validate the source before touching the cache, so that a table from an
	// unknown parent cannot leave an empty builder behind. The dataset would
	// otherwise flush that empty builder downstream.
	switch id {
	case t.wantParent.id:
		t.wantParent.ntables++
	case t.gotParent.id:
		t.gotParent.ntables++
	default:
		return errors.Newf(codes.Internal, "unexpected dataset id: %v", id)
	}

	builder, created := t.cache.TableBuilder(tbl.Key())
	if created {
		// First table for this key: buffer it for the later comparison.
		colMap, err := execute.AddNewTableCols(tbl, builder, make([]int, 0, len(tbl.Cols())))
		if err != nil {
			return err
		}
		if err := execute.AppendMappedTable(tbl, builder, colMap); err != nil {
			return err
		}
		t.keysMatched++
		return nil
	}

	// A table with this key was already buffered; this one pairs with it.
	t.keysMatched--
	cached, err := builder.Table()
	if err != nil {
		return err
	}
	equal, err := execute.TablesEqual(cached, tbl, t.a)
	if err != nil {
		return err
	}
	if !equal {
		t.unequal = true
		return &AssertEqualsError{fmt.Sprintf("test %s: tables not equal", t.name)}
	}
	return nil
}

// UpdateWatermark records mark for the parent identified by id. It then
// forwards the smaller of the two parents' watermarks to the output dataset.
// An unknown id yields a codes.Internal error.
func (t *AssertEqualsTransformation) UpdateWatermark(id execute.DatasetID, mark execute.Time) error {
	t.mu.Lock()
	defer t.mu.Unlock()

	var other *assertEqualsParentState
	switch id {
	case t.gotParent.id:
		t.gotParent.mark = mark
		other = t.wantParent
	case t.wantParent.id:
		t.wantParent.mark = mark
		other = t.gotParent
	default:
		return errors.Newf(codes.Internal, "unexpected dataset id: %v", id)
	}

	// The output may only advance as far as the slower input.
	lowest := mark
	if other.mark < lowest {
		lowest = other.mark
	}
	return t.d.UpdateWatermark(lowest)
}

// UpdateProcessingTime records pt for the parent identified by id. It then
// forwards the smaller of the two parents' processing times to the output
// dataset. An unknown id yields a codes.Internal error.
func (t *AssertEqualsTransformation) UpdateProcessingTime(id execute.DatasetID, pt execute.Time) error {
	t.mu.Lock()
	defer t.mu.Unlock()

	var other *assertEqualsParentState
	switch id {
	case t.gotParent.id:
		t.gotParent.processing = pt
		other = t.wantParent
	case t.wantParent.id:
		t.wantParent.processing = pt
		other = t.gotParent
	default:
		return errors.Newf(codes.Internal, "unexpected dataset id: %v", id)
	}

	// The output may only advance as far as the slower input.
	lowest := pt
	if other.processing < lowest {
		lowest = other.processing
	}
	return t.d.UpdateProcessingTime(lowest)
}

// Finish marks the parent identified by id as finished and remembers err if
// it is non-nil. Once both parents have finished, the output dataset is
// finished with the recorded error.
//
// If no error was recorded and no table pair compared unequal, the final
// bookkeeping is checked:
//   - differing table counts from got and want yield an *AssertEqualsError;
//   - otherwise, a group key seen on only one side yields an *AssertEqualsError.
//
// An upstream error always takes precedence over these derived assertion
// errors. An unknown id finishes the dataset with a codes.Internal error and
// returns immediately.
func (t *AssertEqualsTransformation) Finish(id execute.DatasetID, err error) {
	t.mu.Lock()
	defer t.mu.Unlock()

	switch id {
	case t.gotParent.id:
		t.gotParent.finished = true
	case t.wantParent.id:
		t.wantParent.finished = true
	default:
		// Stop here: continuing could finish the dataset a second time below.
		t.d.Finish(errors.Newf(codes.Internal, "unexpected dataset id: %v", id))
		return
	}

	if err != nil {
		t.err = err
	}

	if !t.gotParent.finished || !t.wantParent.finished {
		return
	}

	// Only diagnose count/key-set mismatches when nothing has failed yet, so a
	// real upstream failure is not masked by a derived assertion message.
	if !t.unequal && t.err == nil {
		switch {
		case t.wantParent.ntables != t.gotParent.ntables:
			t.err = &AssertEqualsError{"assertEquals streams had unequal table counts"}
		case t.keysMatched > 0:
			t.err = &AssertEqualsError{fmt.Sprintf("test %s: unequal group key sets", t.name)}
		}
	}
	t.d.Finish(t.err)
}
