enumgen

package
v0.0.7 Latest Latest
Warning

This package is not in the latest version of its module.

Go to latest
Published: Mar 1, 2024 License: BSD-3-Clause Imports: 19 Imported by: 0

Documentation

Overview

Package enumgen provides functions for generating enum methods for enum types.

Index

Constants

This section is empty.

Variables

View Source
// AllowedEnumTypes are the types that can be used for enums that are
// not bit flags (bit flags can only be int64s). It is stored as a map
// for quick and convenient access.
var AllowedEnumTypes = map[string]bool{
	// Signed integer types.
	"int":   true,
	"int8":  true,
	"int16": true,
	"int32": true,
	"int64": true,
	// Unsigned integer types.
	"uint":   true,
	"uint8":  true,
	"uint16": true,
	"uint32": true,
	"uint64": true,
}

AllowedEnumTypes are the types that can be used for enums that are not bit flags (bit flags can only be int64s). It is stored as a map for quick and convenient access.

View Source
// DescMethodTmpl is the template for the generated Desc method. The
// generated method returns the entry for the value in _{{.Name}}DescMap
// when there is one; otherwise it falls back to String, or to the Desc
// method of the extended type when .Extends is non-empty.
var DescMethodTmpl = template.Must(template.New("DescMethod").Parse(`// Desc returns the description of the {{.Name}} value.
func (i {{.Name}}) Desc() string {
	if str, ok := _{{.Name}}DescMap[i]; ok {
		return str
	} {{if eq .Extends ""}}
	return i.String() {{else}}
	return {{.Extends}}(i).Desc() {{end}}
}
`))
View Source
// GQLMethodsTmpl is the template for the generated MarshalGQL and
// UnmarshalGQL methods. The generated MarshalGQL writes the quoted String
// form of the value to the writer; the generated UnmarshalGQL requires a
// string value and passes it to SetString.
var GQLMethodsTmpl = template.Must(template.New("GQLMethods").Parse(`
// MarshalGQL implements the [graphql.Marshaler] interface.
func (i {{.Name}}) MarshalGQL(w io.Writer) {
	w.Write([]byte(strconv.Quote(i.String())))
}

// UnmarshalGQL implements the [graphql.Unmarshaler] interface.
func (i *{{.Name}}) UnmarshalGQL(value any) error {
	str, ok := value.(string)
	if !ok {
		return fmt.Errorf("{{.Name}} should be a string, but got a value of type %T instead", value)
	}
	return i.SetString(str)
}
`))
View Source
// HasFlagMethodTmpl is the template for the generated HasFlag method,
// which atomically loads the flags as an int64 and tests the bit whose
// index is the Int64 value of the given flag.
var HasFlagMethodTmpl = template.Must(template.New("HasFlagMethod").Parse(
	`// HasFlag returns whether these
// bit flags have the given bit flag set.
func (i {{.Name}}) HasFlag(f enums.BitFlag) bool {
	return atomic.LoadInt64((*int64)(&i))&(1<<uint32(f.Int64())) != 0
}
`))
View Source
// Int64MethodTmpl is the template for the generated Int64 method,
// which converts the enum value to an int64.
var Int64MethodTmpl = template.Must(template.New("Int64Method").Parse(
	`// Int64 returns the {{.Name}} value as an int64.
func (i {{.Name}}) Int64() int64 {
	return int64(i)
}
`))
View Source
// IsValidMethodMapTmpl is the template for the generated IsValid method,
// which reports whether the value is a key of _{{.Name}}Map. When
// .Extends is non-empty, a value that is not found there is checked with
// the IsValid method of the extended type instead.
var IsValidMethodMapTmpl = template.Must(template.New("IsValidMethodMap").Parse(
	`// IsValid returns whether the value is a
// valid option for type {{.Name}}.
func (i {{.Name}}) IsValid() bool {
	_, ok := _{{.Name}}Map[i] {{if ne .Extends ""}}
	if !ok {
		return {{.Extends}}(i).IsValid()
	} {{end}}
	return ok
}
`))
View Source
// JSONMethodsTmpl is the template for the generated MarshalJSON and
// UnmarshalJSON methods. The generated MarshalJSON encodes the String
// form of the value; the generated UnmarshalJSON decodes a JSON string
// and passes it to SetString.
//
// NOTE: a SetString error is only logged with log.Println, so the
// generated UnmarshalJSON still returns nil for a name it cannot set.
var JSONMethodsTmpl = template.Must(template.New("JSONMethods").Parse(
	`
// MarshalJSON implements the [json.Marshaler] interface.
func (i {{.Name}}) MarshalJSON() ([]byte, error) {
	return json.Marshal(i.String())
}

// UnmarshalJSON implements the [json.Unmarshaler] interface.
func (i *{{.Name}}) UnmarshalJSON(data []byte) error {
	var s string
	if err := json.Unmarshal(data, &s); err != nil {
		return err
	}
	if err := i.SetString(s); err != nil {
		log.Println("{{.Name}}.UnmarshalJSON:", err)
	}
	return nil
}
`))
View Source
// NConstantTmpl is the template for the generated {{.Name}}N constant,
// which is declared with type {{.Name}} and set to .MaxValueP1 (the
// highest valid value plus one).
//
// The generated comment starts with "// " so that it is a well-formed
// Go doc comment for the constant.
var NConstantTmpl = template.Must(template.New("StringNConstant").Parse(
	`// {{.Name}}N is the highest valid value
// for type {{.Name}}, plus one.
const {{.Name}}N {{.Name}} = {{.MaxValueP1}}
`))
View Source
// ScanMethodTmpl is the template for the generated Scan method. The
// generated method leaves the value untouched for a nil input, accepts
// []byte, string, and fmt.Stringer inputs by converting them to a string
// and passing it to SetString, and returns an error for any other type.
//
// The generated doc comment names the method it documents (Scan).
var ScanMethodTmpl = template.Must(template.New("ScanMethod").Parse(
	`// Scan implements the [sql.Scanner] interface.
func (i *{{.Name}}) Scan(value any) error {
	if value == nil {
		return nil
	}

	var str string
	switch v := value.(type) {
	case []byte:
		str = string(v)
	case string:
		str = v
	case fmt.Stringer:
		str = v.String()
	default:
		return fmt.Errorf("invalid value for type {{.Name}}: %[1]T(%[1]v)", value)
	}

	return i.SetString(str)
}
`))
View Source
// SetFlagMethodTmpl is the template for the generated SetFlag method,
// which builds a mask from the bit indexes (Int64 values) of the given
// flags and then sets or clears those bits depending on the on argument.
//
// NOTE: the generated code reads the current value with a plain
// dereference; only the final store goes through atomic.StoreInt64.
var SetFlagMethodTmpl = template.Must(template.New("SetFlagMethod").Parse(
	`// SetFlag sets the value of the given
// flags in these flags to the given value.
func (i *{{.Name}}) SetFlag(on bool, f ...enums.BitFlag) {
	var mask int64
	for _, v := range f {
		mask |= 1 << v.Int64()
	}
	in := int64(*i)
	if on {
		in |= mask
		atomic.StoreInt64((*int64)(i), in)
	} else {
		in &^= mask
		atomic.StoreInt64((*int64)(i), in)
	}
}
`))
View Source
// SetInt64MethodTmpl is the template for the generated SetInt64 method,
// which assigns the given int64, converted to {{.Name}}, to the receiver.
var SetInt64MethodTmpl = template.Must(template.New("SetInt64Method").Parse(
	`// SetInt64 sets the {{.Name}} value from an int64.
func (i *{{.Name}}) SetInt64(in int64) {
	*i = {{.Name}}(in)
}
`))
View Source
// SetStringMethodBitFlagTmpl is the template for the bit flag variant of
// the generated SetString method, which resets the value to 0 and then
// delegates to SetStringOr.
var SetStringMethodBitFlagTmpl = template.Must(template.New("SetStringMethodBitFlag").Parse(
	`// SetString sets the {{.Name}} value from its
// string representation, and returns an
// error if the string is invalid.
func (i *{{.Name}}) SetString(s string) error {
	*i = 0
	return i.SetStringOr(s)
}
`))
View Source
// SetStringMethodTmpl is the template for the generated SetString method.
// The generated method looks the string up in _{{.Name}}NameToValueMap
// and, when .Config.AcceptLower is set, retries with the lowercased
// string. If nothing matches it returns an error, or defers to the
// SetString method of the extended type when .Extends is non-empty.
var SetStringMethodTmpl = template.Must(template.New("SetStringMethod").Parse(
	`// SetString sets the {{.Name}} value from its
// string representation, and returns an
// error if the string is invalid.
func (i *{{.Name}}) SetString(s string) error {
	if val, ok := _{{.Name}}NameToValueMap[s]; ok {
		*i = val
		return nil
	} {{if .Config.AcceptLower}}
	if val, ok := _{{.Name}}NameToValueMap[strings.ToLower(s)]; ok {
		*i = val
		return nil
	} {{end}} {{if eq .Extends ""}}
	return errors.New(s+" is not a valid value for type {{.Name}}") {{else}}
	return (*{{.Extends}})(i).SetString(s) {{end}}
}
`))
View Source
// SetStringOrMethodBitFlagTmpl is the template for the generated
// SetStringOr method. The generated method splits the string on "|" and
// sets the flag for every segment found in _{{.Name}}NameToValueMap
// (also trying the lowercased segment when .Config.AcceptLower is set),
// without clearing flags that are already set. Empty segments are
// skipped; any other segment yields an error, or is handed to the
// SetStringOr method of the extended type when .Extends is non-empty.
var SetStringOrMethodBitFlagTmpl = template.Must(template.New("SetStringOrMethodBitFlag").Parse(
	`// SetStringOr sets the {{.Name}} value from its
// string representation while preserving any
// bit flags already set, and returns an
// error if the string is invalid.
func (i *{{.Name}}) SetStringOr(s string) error {
	flgs := strings.Split(s, "|")
	for _, flg := range flgs {
		if val, ok := _{{.Name}}NameToValueMap[flg]; ok {
			i.SetFlag(true, &val)
		{{if .Config.AcceptLower}} } else if val, ok := _{{.Name}}NameToValueMap[strings.ToLower(flg)]; ok {
			i.SetFlag(true, &val)
		{{end}} } else if flg == "" {
			continue
		} else { {{if eq .Extends ""}}
			return fmt.Errorf("%q is not a valid value for type {{.Name}}", flg){{else}}
			err := (*{{.Extends}})(i).SetStringOr(flg)
			if err != nil {
				return err
			}{{end}}
		}
	}
	return nil
}
`))
View Source
// StringMethodBitFlagTmpl is the template for the bit flag variant of
// the generated String method. The generated method joins with "|" the
// BitIndexString of every flag that HasFlag reports as set, visiting the
// values of the extended type first when .Extends is non-empty and then
// the values in _{{.Name}}Values.
var StringMethodBitFlagTmpl = template.Must(template.New("StringMethodBitFlag").Parse(
	`// String returns the string representation
// of this {{.Name}} value.
func (i {{.Name}}) String() string {
	str := "" {{if ne .Extends ""}}
	for _, ie := range {{.Extends}}Values() {
		if i.HasFlag(ie) {
			ies := ie.BitIndexString()
			if str == "" {
				str = ies
			} else {
				str += "|" + ies
			}
		}
	}{{end}}
	for _, ie := range _{{.Name}}Values {
		if i.HasFlag(ie) {
			ies := ie.BitIndexString()
			if str == "" {
				str = ies
			} else {
				str += "|" + ies
			}
		}
	}
	return str
}
`))
View Source
// StringMethodMapTmpl is the template for the generated map-based string
// method, which is named BitIndexString when .IsBitFlag is true and
// String otherwise. The generated method returns the entry for the value
// in _{{.Name}}Map when there is one; otherwise it returns the decimal
// form of the value, or the result of the same method on the extended
// type when .Extends is non-empty.
var StringMethodMapTmpl = template.Must(template.New("StringMethodMap").Parse(
	`{{if .IsBitFlag}}
	// BitIndexString returns the string
	// representation of this {{.Name}} value
	// if it is a bit index value
	// (typically an enum constant), and
	// not an actual bit flag value.
	{{- else}}
	// String returns the string representation
	// of this {{.Name}} value.
	{{- end}}
func (i {{.Name}}) {{if .IsBitFlag}} BitIndexString {{else}} String {{end}} () string {
	if str, ok := _{{.Name}}Map[i]; ok {
		return str
	} {{if eq .Extends ""}}
	return strconv.FormatInt(int64(i), 10) {{else}}
	return {{.Extends}}(i).{{if .IsBitFlag}} BitIndexString {{else}} String {{end}}() {{end}}
}
`))
View Source
// TextMethodsTmpl is the template for the generated MarshalText and
// UnmarshalText methods. The generated MarshalText returns the bytes of
// the String form of the value; the generated UnmarshalText passes the
// text to SetString.
//
// NOTE: a SetString error is only logged with log.Println, so the
// generated UnmarshalText always returns nil.
var TextMethodsTmpl = template.Must(template.New("TextMethods").Parse(
	`
// MarshalText implements the [encoding.TextMarshaler] interface.
func (i {{.Name}}) MarshalText() ([]byte, error) {
	return []byte(i.String()), nil
}

// UnmarshalText implements the [encoding.TextUnmarshaler] interface.
func (i *{{.Name}}) UnmarshalText(text []byte) error {
	if err := i.SetString(string(text)); err != nil {
		log.Println("{{.Name}}.UnmarshalText:", err)
	}
	return nil
}
`))
View Source
// ValueMethodTmpl is the template for the generated Value method, which
// returns the String form of the enum value as its driver.Value.
//
// The generated doc comment names the method it documents (Value).
var ValueMethodTmpl = template.Must(template.New("ValueMethod").Parse(
	`// Value implements the [driver.Valuer] interface.
func (i {{.Name}}) Value() (driver.Value, error) {
	return i.String(), nil
}
`))
View Source
// ValuesGlobalTmpl is the template for the generated {{.Name}}Values
// function. The generated function returns _{{.Name}}Values directly,
// or, when .Extends is non-empty, a new slice holding the values of the
// extended type converted to {{.Name}} followed by _{{.Name}}Values.
var ValuesGlobalTmpl = template.Must(template.New("ValuesGlobal").Parse(
	`// {{.Name}}Values returns all possible values
// for the type {{.Name}}.
func {{.Name}}Values() []{{.Name}} { {{if eq .Extends ""}}
	return _{{.Name}}Values {{else}}
	es := {{.Extends}}Values()
	res := make([]{{.Name}}, len(es))
	for i, e := range es {
		res[i] = {{.Name}}(e)
	}
	res = append(res, _{{.Name}}Values...)
	return res {{end}}
}
`))
View Source
// ValuesMethodTmpl is the template for the generated Values method. The
// generated method returns the values in _{{.Name}}Values as a
// []enums.Enum, preceded by the values of the extended type when
// .Extends is non-empty.
var ValuesMethodTmpl = template.Must(template.New("ValuesMethod").Parse(
	`// Values returns all possible values
// for the type {{.Name}}.
func (i {{.Name}}) Values() []enums.Enum { {{if eq .Extends ""}}
	res := make([]enums.Enum, len(_{{.Name}}Values))
	for i, d := range _{{.Name}}Values {
		res[i] = d
	} {{else}}
	es := {{.Extends}}Values()
	les := len(es)
	res := make([]enums.Enum, les + len(_{{.Name}}Values))
	for i, d := range es {
		res[i] = d
	}
	for i, d := range _{{.Name}}Values {
		res[i + les] = d
	} {{end}}
	return res 
}
`))
View Source
// YAMLMethodsTmpl is the template for the generated MarshalYAML and
// UnmarshalYAML methods. The generated MarshalYAML returns the String
// form of the value; the generated UnmarshalYAML decodes the node into a
// string and passes it to SetString.
//
// The generated UnmarshalYAML decodes through its own parameter (value);
// referring to any other identifier there would make the generated code
// fail to compile.
//
// NOTE: a SetString error is only logged with log.Println, so the
// generated UnmarshalYAML still returns nil for a name it cannot set.
var YAMLMethodsTmpl = template.Must(template.New("YAMLMethods").Parse(
	`
// MarshalYAML implements the [yaml.Marshaler] interface.
func (i {{.Name}}) MarshalYAML() (any, error) {
	return i.String(), nil
}

// UnmarshalYAML implements the [yaml.Unmarshaler] interface.
func (i *{{.Name}}) UnmarshalYAML(value *yaml.Node) error {
	var s string
	if err := value.Decode(&s); err != nil {
		return err
	}
	if err := i.SetString(s); err != nil {
		log.Println("{{.Name}}.UnmarshalYAML:", err)
	}
	return nil
}
`))

Functions

func Generate

func Generate(cfg *Config) error

Generate generates enum methods, using the configuration information, loading the packages from the configuration source directory, and writing the result to the configuration output file.

It is a simple entry point to enumgen that does all of the steps; for more specific functionality, create a new Generator with NewGenerator and call methods on it.

func GeneratePkgs

func GeneratePkgs(cfg *Config, pkgs []*packages.Package) error

GeneratePkgs generates enum methods using the given configuration object and packages parsed from the configuration source directory, and writes the result to the config output file. It is a simple entry point to enumgen that does all of the steps; for more specific functionality, create a new Generator with NewGenerator and call methods on it.

func PackageModes

func PackageModes() packages.LoadMode

PackageModes returns the package load modes needed for this generator

func ParsePackages

func ParsePackages(cfg *Config) ([]*packages.Package, error)

ParsePackages parses the package(s) located in the configuration source directory.

Types

type ByValue

// ByValue is a sorting method that sorts the constants into increasing
// order. We take care in the Less method to sort in signed or unsigned
// order, as appropriate.
type ByValue []Value

ByValue is a sorting method that sorts the constants into increasing order. We take care in the Less method to sort in signed or unsigned order, as appropriate.

func (ByValue) Len

func (b ByValue) Len() int

func (ByValue) Less

func (b ByValue) Less(i, j int) bool

func (ByValue) Swap

func (b ByValue) Swap(i, j int)

type Config

// Config contains the configuration information used by enumgen.
// Default values, where a field has one, are given by its default
// struct tag.
type Config struct {

	// the source directory to run enumgen on (can be set to multiple through paths like ./...)
	Dir string `default:"." posarg:"0" required:"-"`

	// the output file location relative to the package on which enumgen is being called
	Output string `default:"enumgen.go"`

	// if specified, the enum item transformation method (upper, lower, snake, SNAKE, kebab, KEBAB,
	// camel, lower-camel, title, sentence, first, first-upper, or first-lower)
	Transform string

	// if specified, a comma-separated list of prefixes to trim from each item
	TrimPrefix string

	// if specified, the prefix to add to each item
	AddPrefix string

	// whether to use line comment text as printed text when present
	LineComment bool

	// whether to accept lowercase versions of enum names in SetString
	AcceptLower bool

	// whether to generate text marshaling methods
	Text bool `default:"true"`

	// whether to generate JSON marshaling methods (note that text marshaling methods will also work for JSON, so this should be unnecessary in almost all cases; see the text option)
	JSON bool

	// whether to generate YAML marshaling methods
	YAML bool

	// whether to generate methods that implement the SQL Scanner and Valuer interfaces
	SQL bool

	// whether to generate GraphQL marshaling methods for gqlgen
	GQL bool

	// whether to allow enums to extend other enums; this should be on in almost all circumstances,
	// but can be turned off for specific enum types that extend non-enum types
	Extend bool `default:"true"`
}

Config contains the configuration information used by enumgen

type Generator

// Generator holds the state of the generator.
// It is primarily used to buffer the output.
type Generator struct {
	Config *Config             // The configuration information.
	Buf    bytes.Buffer        // The accumulated output.
	Pkgs   []*packages.Package // The packages we are scanning.
	Pkg    *packages.Package   // The package we are currently on.
	Types  []*Type             // The enum types.
}

Generator holds the state of the generator. It is primarily used to buffer the output.

func NewGenerator

func NewGenerator(config *Config, pkgs []*packages.Package) *Generator

NewGenerator returns a new generator with the given configuration information and parsed packages.

func (*Generator) AddValueAndScanMethod

func (g *Generator) AddValueAndScanMethod(typ *Type)

func (*Generator) BuildBasicMethods

func (g *Generator) BuildBasicMethods(values []Value, typ *Type)

BuildBasicMethods builds methods common to all types, like Desc and SetString.

func (*Generator) BuildBitFlagMethods

func (g *Generator) BuildBitFlagMethods(runs []Value, typ *Type)

BuildBitFlagMethods builds methods specific to bit flag types.

func (*Generator) BuildGQLMethods

func (g *Generator) BuildGQLMethods(runs []Value, typ *Type)

func (*Generator) BuildJSONMethods

func (g *Generator) BuildJSONMethods(runs []Value, typ *Type)

func (*Generator) BuildNoOpOrderChangeDetect

func (g *Generator) BuildNoOpOrderChangeDetect(values []Value, typ *Type)

BuildNoOpOrderChangeDetect lets the compiler and the user know if the order/value of the enum values has changed.

func (*Generator) BuildString

func (g *Generator) BuildString(values []Value, typ *Type)

BuildString builds the string function using a map access approach.

func (*Generator) BuildTextMethods

func (g *Generator) BuildTextMethods(runs []Value, typ *Type)

func (*Generator) BuildYAMLMethods

func (g *Generator) BuildYAMLMethods(runs []Value, typ *Type)

func (*Generator) ExecTmpl

func (g *Generator) ExecTmpl(t *template.Template, typ *Type)

ExecTmpl executes the given template with the given type and writes the result to [Generator.Buf]. It fatally logs any error. All enumgen templates take a Type as their data.

func (*Generator) FindEnumTypes

func (g *Generator) FindEnumTypes() error

FindEnumTypes goes through all of the types in the package and finds all integer (signed or unsigned) types labeled with enums:enum or enums:bitflag. It stores the resulting types in [Generator.Types].

func (*Generator) GenDecl

func (g *Generator) GenDecl(node ast.Node, file *ast.File, typ *Type) ([]Value, bool, error)

GenDecl processes one declaration clause. It returns the resulting values, whether the AST inspector should continue, and an error if there is one. It should only be called in ast.Inspect.

func (*Generator) Generate

func (g *Generator) Generate() (bool, error)

Generate produces the enum methods for the types stored in [Generator.Types] and stores them in [Generator.Buf]. It returns whether there were any enum types to generate methods for, and any error that occurred.

func (*Generator) InspectForType

func (g *Generator) InspectForType(n ast.Node) (bool, error)

InspectForType looks at the given AST node and adds it to [Generator.Types] if it is marked with an appropriate comment directive. It returns whether the AST inspector should continue, and an error if there is one. It should only be called in ast.Inspect.

func (*Generator) PrefixValueNames

func (g *Generator) PrefixValueNames(values []Value, c *Config)

PrefixValueNames adds the prefix specified in [Config.AddPrefix] to each name of the given values.

func (*Generator) PrintDescMap

func (g *Generator) PrintDescMap(values []Value, typ *Type)

PrintDescMap prints the map of values to descriptions

func (*Generator) PrintHeader

func (g *Generator) PrintHeader()

PrintHeader prints the header and package clause to the accumulated output

func (*Generator) PrintValueMap

func (g *Generator) PrintValueMap(values []Value, typ *Type)

PrintValueMap prints the map between name and value

func (*Generator) Printf

func (g *Generator) Printf(format string, args ...any)

Printf prints the formatted string to the accumulated output in [Generator.Buf]

func (*Generator) TransformValueNames

func (g *Generator) TransformValueNames(values []Value, c *Config) error

TransformValueNames transforms the names of the given values according to the transform method specified in [Config.Transform]

func (*Generator) TrimValueNames

func (g *Generator) TrimValueNames(values []Value, c *Config)

TrimValueNames removes the prefixes specified in [Config.TrimPrefix] from each name of the given values.

func (*Generator) Write

func (g *Generator) Write() error

Write formats the data in the Generator's buffer ([Generator.Buf]) and writes it to the file specified by [Generator.Config.Output].

type Type

// Type represents a parsed enum type.
type Type struct {
	Name       string        // The name of the type
	Type       *ast.TypeSpec // The standard AST type value
	IsBitFlag  bool          // Whether the type is a bit flag type
	Extends    string        // The type that this type extends, if any ("" if it doesn't extend)
	MaxValueP1 int64         // The highest defined value for the type, plus one
	Config     *Config       // Configuration information set in the comment directive for the type; is initialized to generator config info first
}

Type represents a parsed enum type.

type Value

// Value represents a declared constant.
type Value struct {
	OriginalName string // The name of the constant before transformation
	Name         string // The name of the constant after transformation (e.g. camel case => snake case)
	Desc         string // The comment description of the constant
	// The Value is stored as a bit pattern alone. The Signed boolean tells
	// us whether to interpret it as an int64 or a uint64; the only place
	// this matters is when sorting.
	// Much of the time the Str field is all we need; it is printed
	// by Value.String.
	Value  int64
	Signed bool   // Whether the constant is a signed type.
	Str    string // The string representation given by the "go/constant" package.
}

Value represents a declared constant.

func SortValues

func SortValues(values []Value) []Value

SortValues sorts the values and ensures there are no duplicates. The input slice is known to be non-empty.

func (*Value) String

func (v *Value) String() string

Jump to

Keyboard shortcuts

? : This menu
/ : Search site
f or F : Jump to
y or Y : Canonical URL