package main

import (
	"bytes"
	"fmt"
	"log"
	"os"
	"sort"
	"strconv"
	"strings"

	"github.com/lestrrat-go/codegen"
	"github.com/pkg/errors"
)

// main runs the header generator. On failure it logs the error and exits
// with status 1; on success it returns normally (exit status 0).
func main() {
	err := _main()
	if err == nil {
		return
	}
	log.Printf("%s", err)
	os.Exit(1)
}

// _main performs the actual work of the program and reports any failure as
// an error, leaving process exit handling to main.
func _main() error {
	if err := generateHeaders(); err != nil {
		return err
	}
	return nil
}

// headerField describes one standard JWS header parameter for which
// accessor, storage, and (un)marshaling code is generated.
type headerField struct {
	name      string // unexported field name used in the generated stdHeaders struct
	method    string // accessor method name; also prefixes the generated <method>Key constant
	typ       string // Go type returned by the generated accessor
	key       string // header parameter name as it appears in JSON
	comment   string // emitted as a trailing comment on the generated struct field
	hasAccept bool   // when true, the generated setter validates input via the type's Accept method
	jsonTag   string // struct tag literal emitted on the marshal proxy field
}

// IsPointer reports whether the declared type string begins with '*'.
func (f headerField) IsPointer() bool {
	return len(f.typ) > 0 && f.typ[0] == '*'
}

// PointerElem returns the declared type with at most one leading '*'
// stripped; non-pointer types are returned unchanged.
func (f headerField) PointerElem() string {
	if !f.IsPointer() {
		return f.typ
	}
	return f.typ[1:]
}

// zerovals maps a generated field's Go type to the source text of its zero
// value, for types whose zero value is not the literal nil. Types that are
// absent here (e.g. slices such as []string) fall back to "nil" in zeroval.
var zerovals = map[string]string{
	"string":                 `""`,
	"jwa.SignatureAlgorithm": `""`,
}

// zeroval returns the Go source text of the zero value for type s, used by
// the generated accessors when a field is unset. Types not listed in
// zerovals yield "nil".
func zeroval(s string) string {
	if v, ok := zerovals[s]; ok {
		return v
	}
	return "nil"
}

// fieldStorageType returns the Go type used to store a header field of
// declared type s inside the generated struct. Types that cannot already
// hold nil are stored behind a pointer so "unset" is distinguishable.
func fieldStorageType(s string) string {
	if !fieldStorageTypeIsIndirect(s) {
		return s
	}
	return "*" + s
}

// fieldStorageTypeIsIndirect reports whether a field of declared type s is
// stored through an added pointer. jwk.Key, pointer types, and slice types
// are stored as-is; everything else is indirect.
func fieldStorageTypeIsIndirect(s string) bool {
	switch {
	case s == "jwk.Key", strings.HasPrefix(s, "*"), strings.HasPrefix(s, "[]"):
		return false
	default:
		return true
	}
}

// generateHeaders writes headers_gen.go for package jws. The generated file
// contains:
//   - the Headers interface and the <Method>Key constants;
//   - the stdHeaders implementation, with accessors, Get/Set/Remove, and
//     pair iteration;
//   - its JSON (un)marshaling methods.
//
// Fields are sorted by name so regeneration is deterministic. If codegen
// reports a formatting error, the offending source is dumped to stderr
// before the error is returned.
func generateHeaders() error {
	// Name of the field whose marshal-proxy representation is a raw JSON
	// message rather than its storage type.
	const jwkKey = "jwk"

	fields := []headerField{
		{
			name:      `algorithm`,
			method:    `Algorithm`,
			typ:       `jwa.SignatureAlgorithm`,
			key:       `alg`,
			comment:   `https://tools.ietf.org/html/rfc7515#section-4.1.1`,
			hasAccept: true,
			jsonTag:   "`" + `json:"alg,omitempty"` + "`",
		},
		{
			name:    `contentType`,
			method:  `ContentType`,
			typ:     `string`,
			key:     `cty`,
			comment: `https://tools.ietf.org/html/rfc7515#section-4.1.10`,
			jsonTag: "`" + `json:"cty,omitempty"` + "`",
		},
		{
			name:    `critical`,
			method:  `Critical`,
			typ:     `[]string`,
			key:     `crit`,
			comment: `https://tools.ietf.org/html/rfc7515#section-4.1.11`,
			jsonTag: "`" + `json:"crit,omitempty"` + "`",
		},
		{
			name:    `jwk`,
			method:  `JWK`,
			typ:     `jwk.Key`,
			key:     `jwk`,
			comment: `https://tools.ietf.org/html/rfc7515#section-4.1.3`,
			jsonTag: "`" + `json:"jwk,omitempty"` + "`",
		},
		{
			name:    `jwkSetURL`,
			method:  `JWKSetURL`,
			typ:     `string`,
			key:     `jku`,
			comment: `https://tools.ietf.org/html/rfc7515#section-4.1.2`,
			jsonTag: "`" + `json:"jku,omitempty"` + "`",
		},
		{
			name:    `keyID`,
			method:  `KeyID`,
			typ:     `string`,
			key:     `kid`,
			comment: `https://tools.ietf.org/html/rfc7515#section-4.1.4`,
			jsonTag: "`" + `json:"kid,omitempty"` + "`",
		},
		{
			name:    `typ`,
			method:  `Type`,
			typ:     `string`,
			key:     `typ`,
			comment: `https://tools.ietf.org/html/rfc7515#section-4.1.9`,
			jsonTag: "`" + `json:"typ,omitempty"` + "`",
		},
		{
			name:    `x509CertChain`,
			method:  `X509CertChain`,
			typ:     `[]string`,
			key:     `x5c`,
			comment: `https://tools.ietf.org/html/rfc7515#section-4.1.6`,
			jsonTag: "`" + `json:"x5c,omitempty"` + "`",
		},
		{
			name:    `x509CertThumbprint`,
			method:  `X509CertThumbprint`,
			typ:     `string`,
			key:     `x5t`,
			comment: `https://tools.ietf.org/html/rfc7515#section-4.1.7`,
			jsonTag: "`" + `json:"x5t,omitempty"` + "`",
		},
		{
			name:    `x509CertThumbprintS256`,
			method:  `X509CertThumbprintS256`,
			typ:     `string`,
			key:     `x5t#S256`,
			comment: `https://tools.ietf.org/html/rfc7515#section-4.1.8`,
			jsonTag: "`" + `json:"x5t#S256,omitempty"` + "`",
		},
		{
			name:    `x509URL`,
			method:  `X509URL`,
			typ:     `string`,
			key:     `x5u`,
			comment: `https://tools.ietf.org/html/rfc7515#section-4.1.5`,
			jsonTag: "`" + `json:"x5u,omitempty"` + "`",
		},
	}

	// Stable order so that regenerating produces byte-identical output.
	sort.Slice(fields, func(i, j int) bool {
		return fields[i].name < fields[j].name
	})

	var buf bytes.Buffer

	o := codegen.NewOutput(&buf)
	// "// Code generated ... DO NOT EDIT." is the marker Go tooling uses to
	// recognize generated files.
	o.L("// Code generated by internal/cmd/genheaders/main.go. DO NOT EDIT.")
	o.LL("package jws")

	// Every package referenced by the generated code is listed here, and
	// nothing else, because missing or unused imports break compilation.
	// base64 must be jwx's internal package: the generated code calls a
	// package-level base64.EncodeToString, which encoding/base64 lacks.
	o.LL("import (")
	pkgs := []string{
		"bytes",
		"context",
		"sort",
		"sync",
		"github.com/lestrrat-go/jwx/internal/base64",
		"github.com/lestrrat-go/jwx/internal/json",
		"github.com/lestrrat-go/jwx/internal/pool",
		"github.com/lestrrat-go/jwx/jwa",
		"github.com/lestrrat-go/jwx/jwk",
		"github.com/pkg/errors",
	}
	for _, pkg := range pkgs {
		o.L("%s", strconv.Quote(pkg))
	}
	o.L(")")

	o.LL("const (")
	for _, f := range fields {
		o.L("%sKey = %s", f.method, strconv.Quote(f.key))
	}
	o.L(")") // end const

	o.LL("// Headers describe a standard Header set.")
	o.L("type Headers interface {")
	o.L("json.Marshaler")
	o.L("json.Unmarshaler")
	// Accessors for the standard header parameters.
	for _, f := range fields {
		o.L("%s() %s", f.method, f.PointerElem())
	}

	// Iteration over all keys in a header.
	o.L("Iterate(ctx context.Context) Iterator")
	o.L("Walk(context.Context, Visitor) error")
	o.L("AsMap(context.Context) (map[string]interface{}, error)")
	o.L("Copy(context.Context, Headers) error")
	o.L("Merge(context.Context, Headers) (Headers, error)")

	// Access to a single element by key name.
	o.L("Get(string) (interface{}, bool)")
	o.L("Set(string, interface{}) error")
	o.L("Remove(string) error")

	o.LL("// PrivateParams returns the non-standard elements in the source structure")
	o.L("// WARNING: DO NOT USE PrivateParams() IF YOU HAVE CONCURRENT CODE ACCESSING THEM.")
	o.L("// Use AsMap() to get a copy of the entire header instead")
	o.L("PrivateParams() map[string]interface{}")

	o.L("}") // end type Headers

	// Storage types are chosen by fieldStorageType so every field is nil
	// when unset; the generated Get relies on this.
	o.LL("type stdHeaders struct {")
	for _, f := range fields {
		o.L("%s %s // %s", f.name, fieldStorageType(f.typ), f.comment)
	}
	o.L("privateParams map[string]interface{}")
	o.L("mu *sync.RWMutex")
	o.L("}") // end type stdHeaders

	// Marshal proxy type. NOTE(review): the UnmarshalJSON generated below
	// decodes token by token and does not reference this proxy; confirm
	// whether anything else in package jws still uses it.
	o.LL("type standardHeadersMarshalProxy struct {")
	for _, f := range fields {
		if f.name == jwkKey {
			o.L("X%s json.RawMessage %s", f.name, f.jsonTag)
		} else {
			o.L("X%s %s %s", f.name, fieldStorageType(f.typ), f.jsonTag)
		}
	}
	o.L("}") // end type standardHeadersMarshalProxy

	o.LL("func NewHeaders() Headers {")
	o.L("return &stdHeaders{")
	o.L("mu: &sync.RWMutex{},")
	o.L("}")
	o.L("}")

	// Accessors return the stored value, or the type's zero value when an
	// indirectly-stored field is unset.
	for _, f := range fields {
		o.LL("func (h *stdHeaders) %s() %s{", f.method, f.typ)
		o.L("h.mu.RLock()")
		o.L("defer h.mu.RUnlock()")
		if fieldStorageTypeIsIndirect(f.typ) {
			o.L("if h.%s == nil {", f.name)
			o.L("return %s", zeroval(f.typ))
			o.L("}")
			o.L("return *(h.%s)", f.name)
		} else {
			o.L("return h.%s", f.name)
		}
		o.L("}") // end accessor
	}

	// makePairs collects all set standard fields plus private params,
	// sorted by key.
	o.LL("func (h *stdHeaders) makePairs() []*HeaderPair {")
	o.L("h.mu.RLock()")
	o.L("defer h.mu.RUnlock()")
	// NOTE: building up an array is *slow*?
	o.L("var pairs []*HeaderPair")
	for _, f := range fields {
		o.L("if h.%s != nil {", f.name)
		if fieldStorageTypeIsIndirect(f.typ) {
			o.L("pairs = append(pairs, &HeaderPair{Key: %sKey, Value: *(h.%s)})", f.method, f.name)
		} else {
			o.L("pairs = append(pairs, &HeaderPair{Key: %sKey, Value: h.%s})", f.method, f.name)
		}
		o.L("}")
	}
	o.L("for k, v := range h.privateParams {")
	o.L("pairs = append(pairs, &HeaderPair{Key: k, Value: v})")
	o.L("}")
	o.L("sort.Slice(pairs, func(i, j int) bool {")
	o.L("return pairs[i].Key.(string) < pairs[j].Key.(string)")
	o.L("})")
	o.L("return pairs")
	o.L("}") // end makePairs

	o.LL("func (h *stdHeaders) PrivateParams() map[string]interface{} {")
	o.L("h.mu.RLock()")
	o.L("defer h.mu.RUnlock()")
	o.L("return h.privateParams")
	o.L("}")

	o.LL("func (h *stdHeaders) Get(name string) (interface{}, bool) {")
	o.L("h.mu.RLock()")
	o.L("defer h.mu.RUnlock()")
	o.L("switch name {")
	for _, f := range fields {
		o.L("case %sKey:", f.method)
		o.L("if h.%s == nil {", f.name)
		o.L("return nil, false")
		o.L("}")
		if fieldStorageTypeIsIndirect(f.typ) {
			o.L("return *(h.%s), true", f.name)
		} else {
			o.L("return h.%s, true", f.name)
		}
	}
	o.L("default:")
	o.L("v, ok := h.privateParams[name]")
	o.L("return v, ok")
	o.L("}") // end switch name
	o.L("}") // end Get

	o.LL("func (h *stdHeaders) Set(name string, value interface{}) error {")
	o.L("h.mu.Lock()")
	o.L("defer h.mu.Unlock()")
	o.L("return h.setNoLock(name, value)")
	o.L("}")

	// setNoLock performs the assignment without locking; Set calls it while
	// holding h.mu.
	o.LL("func (h *stdHeaders) setNoLock(name string, value interface{}) error {")
	o.L("switch name {")
	for _, f := range fields {
		o.L("case %sKey:", f.method)
		if f.hasAccept {
			o.L("var acceptor %s", f.PointerElem())
			o.L("if err := acceptor.Accept(value); err != nil {")
			o.L("return errors.Wrapf(err, `invalid value for %%s key`, %sKey)", f.method)
			o.L("}") // end if err := acceptor.Accept(value)
			o.L("h.%s = &acceptor", f.name)
			o.L("return nil")
		} else {
			o.L("if v, ok := value.(%s); ok {", f.typ)
			if fieldStorageTypeIsIndirect(f.typ) {
				o.L("h.%s = &v", f.name)
			} else {
				o.L("h.%s = v", f.name)
			}
			o.L("return nil")
			o.L("}") // end if v, ok := value.(%s)
			o.L("return errors.Errorf(`invalid value for %%s key: %%T`, %sKey, value)", f.method)
		}
	}
	o.L("default:")
	o.L("if h.privateParams == nil {")
	o.L("h.privateParams = map[string]interface{}{}")
	o.L("}") // end if h.privateParams == nil
	o.L("h.privateParams[name] = value")
	o.L("}") // end switch name
	o.L("return nil")
	o.L("}")

	o.LL("func (h *stdHeaders) Remove(key string) error {")
	o.L("h.mu.Lock()")
	o.L("defer h.mu.Unlock()")
	o.L("switch key {")
	for _, f := range fields {
		o.L("case %sKey:", f.method)
		o.L("h.%s = nil", f.name)
	}
	o.L("default:")
	o.L("delete(h.privateParams, key)")
	o.L("}")
	o.L("return nil") // currently unused, but who knows
	o.L("}")

	// UnmarshalJSON resets the standard fields, then streams tokens and
	// dispatches each key to a decoder chosen by the field's type.
	o.LL("func (h *stdHeaders) UnmarshalJSON(buf []byte) error {")
	for _, f := range fields {
		o.L("h.%s = nil", f.name)
	}

	o.L("dec := json.NewDecoder(bytes.NewReader(buf))")
	o.L("LOOP:")
	o.L("for {")
	o.L("tok, err := dec.Token()")
	o.L("if err != nil {")
	o.L("return errors.Wrap(err, `error reading token`)")
	o.L("}")
	o.L("switch tok := tok.(type) {")
	o.L("case json.Delim:")
	o.L("// Assuming we're doing everything correctly, we should ONLY")
	o.L("// get either '{' or '}' here.")
	o.L("if tok == '}' { // End of object")
	o.L("break LOOP")
	o.L("} else if tok != '{' {")
	o.L("return errors.Errorf(`expected '{', but got '%%c'`, tok)")
	o.L("}")
	o.L("case string: // Objects can only have string keys")
	o.L("switch tok {")

	for _, f := range fields {
		o.L("case %sKey:", f.method)
		switch {
		case f.typ == "string":
			o.L("if err := json.AssignNextStringToken(&h.%s, dec); err != nil {", f.name)
			o.L("return errors.Wrapf(err, `failed to decode value for key %%s`, %sKey)", f.method)
			o.L("}")
		case f.typ == "[]byte":
			o.L("if err := json.AssignNextBytesToken(&h.%s, dec); err != nil {", f.name)
			o.L("return errors.Wrapf(err, `failed to decode value for key %%s`, %sKey)", f.method)
			o.L("}")
		case f.typ == "jwk.Key":
			o.L("var buf json.RawMessage")
			o.L("if err := dec.Decode(&buf); err != nil {")
			o.L("return errors.Wrapf(err, `failed to decode value for key %%s`, %sKey)", f.method)
			o.L("}")
			o.L("key, err := jwk.ParseKey(buf)")
			o.L("if err != nil {")
			o.L("return errors.Wrapf(err, `failed to parse JWK for key %%s`, %sKey)", f.method)
			o.L("}")
			o.L("h.%s = key", f.name)
		case strings.HasPrefix(f.typ, "[]"):
			// Slices are stored directly (see fieldStorageType).
			o.L("var decoded %s", f.typ)
			o.L("if err := dec.Decode(&decoded); err != nil {")
			o.L("return errors.Wrapf(err, `failed to decode value for key %%s`, %sKey)", f.method)
			o.L("}")
			o.L("h.%s = decoded", f.name)
		default:
			// Everything else is stored behind a pointer.
			o.L("var decoded %s", f.typ)
			o.L("if err := dec.Decode(&decoded); err != nil {")
			o.L("return errors.Wrapf(err, `failed to decode value for key %%s`, %sKey)", f.method)
			o.L("}")
			o.L("h.%s = &decoded", f.name)
		}
	}
	o.L("default:")
	o.L("decoded, err := registry.Decode(dec, tok)")
	o.L("if err != nil {")
	o.L("return err")
	o.L("}")
	// Previously this error was silently dropped.
	o.L("if err := h.setNoLock(tok, decoded); err != nil {")
	o.L("return errors.Wrapf(err, `failed to set value for key %%s`, tok)")
	o.L("}")
	o.L("}") // end switch tok (string)
	o.L("default:")
	o.L("return errors.Errorf(`invalid token %%T`, tok)")
	o.L("}") // end switch tok.(type)
	o.L("}") // end for

	o.L("return nil")
	o.L("}") // end UnmarshalJSON

	o.LL("func (h stdHeaders) MarshalJSON() ([]byte, error) {")
	o.L("buf := pool.GetBytesBuffer()")
	o.L("defer pool.ReleaseBytesBuffer(buf)")
	o.L("buf.WriteByte('{')")
	o.L("enc := json.NewEncoder(buf)")
	o.L("for i, p := range h.makePairs() {")
	o.L("if i > 0 {")
	o.L("buf.WriteRune(',')")
	o.L("}")
	// Keys are JSON-encoded rather than written verbatim: private parameter
	// names are arbitrary strings and may contain characters that must be
	// escaped. Encode appends a newline, which is truncated, as for values.
	o.L("if err := enc.Encode(p.Key); err != nil {")
	o.L("return nil, errors.Wrapf(err, `failed to encode key %%v`, p.Key)")
	o.L("}")
	o.L("buf.Truncate(buf.Len()-1)")
	o.L("buf.WriteByte(':')")
	o.L("v := p.Value")
	o.L("switch v := v.(type) {")
	o.L("case []byte:")
	o.L("buf.WriteRune('\"')")
	o.L("buf.WriteString(base64.EncodeToString(v))")
	o.L("buf.WriteRune('\"')")
	o.L("default:")
	// Previously the error was constructed and discarded, so encode failures
	// produced truncated JSON with a nil error.
	o.L("if err := enc.Encode(v); err != nil {")
	o.L("return nil, errors.Wrapf(err, `failed to encode value for field %%s`, p.Key)")
	o.L("}")
	o.L("buf.Truncate(buf.Len()-1)")
	o.L("}")
	o.L("}")
	o.L("buf.WriteByte('}')")
	// Copy out of the pooled buffer before it is released.
	o.L("ret := make([]byte, buf.Len())")
	o.L("copy(ret, buf.Bytes())")
	o.L("return ret, nil")
	o.L("}") // end MarshalJSON

	if err := o.WriteFile(`headers_gen.go`, codegen.WithFormatCode(true)); err != nil {
		if cfe, ok := err.(codegen.CodeFormatError); ok {
			fmt.Fprint(os.Stderr, cfe.Source())
		}
		return errors.Wrap(err, `failed to write to headers_gen.go`)
	}
	return nil
}
