1
0
treerack/format.go

639 lines
12 KiB
Go
Raw Normal View History

2026-05-30 20:14:24 +02:00
package treerack
import (
"bytes"
"fmt"
"io"
"strings"
"unicode"
)
// initialTargetWidth is the target line width that definition formatting
// starts from. It is passed on as the targetWidth of the formatting functions,
// which reduce it with decTargetWidth as names and grouping take up space.
const initialTargetWidth = 112
// commentFormat classifies a top-level AST node by the role it plays as a
// comment during formatting. See topLevelCommentFormat for how the values are
// assigned.
type commentFormat int
const (
// commentFormatNone marks a node that is not a comment.
commentFormatNone commentFormat = iota
// standaloneComment marks a comment that is the last node, is followed by
// another comment, or is separated from the next node by at least two
// newlines.
standaloneComment
// headerComment marks a comment that directly precedes a non-comment node,
// with fewer than two newlines in between.
headerComment
// suffixComment marks a comment that follows a non-comment node without a
// newline in between.
suffixComment
// inlineComment is not produced by topLevelCommentFormat.
// NOTE(review): presumably reserved for comments inside a definition; confirm.
inlineComment
)
// formatItem pairs a top-level AST node with its comment classification.
type formatItem struct {
// commentFormat is commentFormatNone for nodes that are not comments.
commentFormat commentFormat
// node is the AST node to be formatted.
node *Node
}
// formatGroup is a run of top-level items that are formatted together, as
// produced by groupASTByComments. Definition names are padded to a common width
// within one group.
type formatGroup struct {
// items holds the members of the group in source order.
items []formatItem
}
// topLevelCommentFormat classifies n, the i-th child of ast, by its role as a
// top-level comment.
//
// It returns commentFormatNone when n is not a comment. Otherwise it returns:
//   - suffixComment when the previous sibling is not a comment and no newline
//     separates it from n;
//   - standaloneComment when n is the last child, when the next sibling is a
//     comment, or when at least two newlines separate n from the next sibling;
//   - headerComment in all the remaining cases.
func topLevelCommentFormat(ast *Node, i int, n *Node) commentFormat {
	if n.Name != "comment" {
		return commentFormatNone
	}

	if i > 0 {
		if prev := ast.Nodes[i-1]; prev.Name != "comment" {
			// text between the end of the previous node and the comment
			gap := string(ast.tokens[prev.To:n.From])
			if !strings.Contains(gap, "\n") {
				return suffixComment
			}
		}
	}

	if i+1 == len(ast.Nodes) {
		return standaloneComment
	}

	following := ast.Nodes[i+1]
	if following.Name == "comment" {
		return standaloneComment
	}

	// two or more newlines before the next node detach the comment from it
	if strings.Count(string(ast.tokens[n.To:following.From]), "\n") >= 2 {
		return standaloneComment
	}

	return headerComment
}
// groupASTByComments splits the top-level nodes of ast into groups that are
// formatted together.
//
// A standalone or header comment always starts a new group. A non-comment node
// starts a new group when the last item of the current group is a standalone
// comment, otherwise it joins the current group. Suffix comments always join
// the current group.
//
// Only non-empty groups are returned, so an AST without child nodes yields no
// groups at all.
func groupASTByComments(ast *Node) []formatGroup {
	var (
		groups  []formatGroup
		current formatGroup
	)

	// flush closes the current group, dropping it when it has no items, and
	// starts a fresh one with its own backing slice.
	flush := func() {
		if len(current.items) > 0 {
			groups = append(groups, current)
		}

		current = formatGroup{}
	}

	for i, n := range ast.Nodes {
		cf := topLevelCommentFormat(ast, i, n)
		item := formatItem{commentFormat: cf, node: n}
		switch cf {
		case commentFormatNone:
			// a standalone comment stays alone in its group
			last := len(current.items) - 1
			if last >= 0 && current.items[last].commentFormat == standaloneComment {
				flush()
			}
		case suffixComment:
			// belongs to the item it follows
		default:
			flush()
		}

		current.items = append(current.items, item)
	}

	flush()
	return groups
}
// trimComment normalizes the text of a comment node, which may consist of
// several block and line comments separated by whitespace.
//
// The content of block comments is kept as it is. A space is inserted after a
// closing "*/" when it is directly followed by a non-whitespace character, and
// after an opening "//" when it is followed by anything but a space. Outside of
// comments, newlines are kept, while any other character is kept only when the
// output so far does not end in whitespace. Finally, trailing whitespace is
// removed from every line.
func trimComment(text string) string {
	const (
		outside = iota
		inBlock
		inLine
	)

	src := []rune(text)
	out := make([]rune, 0, len(src))

	// followedBy tells whether the rune right after position i is c.
	followedBy := func(i int, c rune) bool {
		return i+1 < len(src) && src[i+1] == c
	}

	state := outside
	for i := 0; i < len(src); i++ {
		c := src[i]
		switch {
		case state == inBlock:
			if c != '*' || !followedBy(i, '/') {
				out = append(out, c)
				continue
			}

			out = append(out, '*', '/')
			state = outside
			if i+2 < len(src) && !unicode.IsSpace(src[i+2]) {
				out = append(out, ' ')
			}

			i++ // skip the '/' of the closing token
		case state == inLine:
			out = append(out, c)
			if c == '\n' {
				state = outside
			}
		case c == '/' && followedBy(i, '*'):
			out = append(out, '/', '*')
			state = inBlock
			i++ // skip the '*' of the opening token
		case c == '/' && followedBy(i, '/'):
			out = append(out, '/', '/')
			state = inLine
			if i+2 < len(src) && src[i+2] != ' ' {
				out = append(out, ' ')
			}

			i++ // skip the second '/'
		case c == '\n' || (len(out) > 0 && !unicode.IsSpace(out[len(out)-1])):
			out = append(out, c)
		}
	}

	lines := strings.Split(string(out), "\n")
	for i, line := range lines {
		lines[i] = strings.TrimRightFunc(line, unicode.IsSpace)
	}

	return strings.Join(lines, "\n")
}
// formatComment writes the text of the comment node n to out, after
// normalizing it with trimComment.
func formatComment(out io.Writer, n *Node) error {
	if _, err := fmt.Fprint(out, trimComment(n.Text())); err != nil {
		return err
	}

	return nil
}
// formatDefinitionName returns the left-hand side of a definition: the text of
// the first child node, followed by the names of all the child nodes between
// the first and the last one, each prefixed with a colon.
func formatDefinitionName(item formatItem) string {
	children := item.node.Nodes
	name := children[0].Text()

	// the children between the name and the last node are appended as flags
	flags := make([]string, 0, len(children)-2)
	for _, flag := range children[1 : len(children)-1] {
		flags = append(flags, flag.Name)
	}

	if len(flags) == 0 {
		return name
	}

	return name + ":" + strings.Join(flags, ":")
}
// formatItemNames returns the formatted definition name for every item of g,
// in item order, together with the byte length of the longest name. Comment
// items get an empty name and don't count towards the width.
func formatItemNames(g formatGroup) ([]string, int) {
	names := make([]string, len(g.items))
	widest := 0
	for i, item := range g.items {
		if item.commentFormat != commentFormatNone {
			// comments keep the empty placeholder
			continue
		}

		names[i] = formatDefinitionName(item)
		if w := len(names[i]); w > widest {
			widest = w
		}
	}

	return names, widest
}
// formatAnyChar writes the any-char expression, a single dot, to out.
func formatAnyChar(out io.Writer) error {
	if _, err := fmt.Fprint(out, "."); err != nil {
		return err
	}

	return nil
}
// formatCharClass writes the text of the char-class node n to out unchanged.
func formatCharClass(out io.Writer, n *Node) error {
	if _, err := fmt.Fprint(out, n.Text()); err != nil {
		return err
	}

	return nil
}
// formatCharSequence writes the text of the char-sequence node n to out
// unchanged.
func formatCharSequence(out io.Writer, n *Node) error {
	if _, err := fmt.Fprint(out, n.Text()); err != nil {
		return err
	}

	return nil
}
// formatSymbol writes the text of the symbol node n to out unchanged.
func formatSymbol(out io.Writer, n *Node) error {
	if _, err := fmt.Fprint(out, n.Text()); err != nil {
		return err
	}

	return nil
}
// decTargetWidth reduces the target width w by the amount by, without going
// below zero. A w that is already zero or negative is returned unchanged, so
// that a negative width is preserved.
func decTargetWidth(w, by int) int {
	if w <= 0 {
		return w
	}

	if rest := w - by; rest > 0 {
		return rest
	}

	return 0
}
// formatSequenceItemNode writes a single item of a sequence: the item
// expression followed by its quantifier.
//
// The first child of n is the expression. When n has exactly two children, the
// second one is read with getQuantity; the resulting range is passed through
// normalizeItemRange. An expression that is a choice or a sequence of more than
// one element is wrapped in parentheses, with the target width reduced by two.
// When the wrapped expression spans multiple lines, it is laid out between
// "( " and a closing parenthesis on its own line.
//
// The quantifier is omitted for the range 1..1, written as "?", "*" or "+" for
// the ranges 0..1, 0..unbounded and 1..unbounded, and as "{n}" or "{n,m}"
// otherwise, where a zero minimum or a negative (unbounded) maximum is left
// out.
//
// It returns the first error from getQuantity, from formatting the expression,
// or from writing to out.
func formatSequenceItemNode(out io.Writer, targetWidth int, n *Node) error {
	var (
		minCount, maxCount int
		err                error
	)

	// fprint writes to out until the first failure, which is kept in err
	fprint := func(a ...any) {
		if err != nil {
			return
		}

		_, err = fmt.Fprint(out, a...)
	}

	if len(n.Nodes) == 2 {
		if minCount, maxCount, err = getQuantity(n.Nodes[1]); err != nil {
			return err
		}
	}

	minCount, maxCount = normalizeItemRange(minCount, maxCount)

	expression := n.Nodes[0]
	isComposite := expression.Name == "choice" || expression.Name == "sequence"
	needsGrouping := isComposite && len(expression.Nodes) > 1
	if needsGrouping {
		var buf bytes.Buffer
		targetWidth = decTargetWidth(targetWidth, 2)
		if err := formatExpression(&buf, targetWidth, expression); err != nil {
			return err
		}

		body := buf.String()
		if strings.Contains(body, "\n") {
			lines := strings.Split(body, "\n")
			fprint("( ")
			fprint(lines[0])
			for _, l := range lines[1:] {
				fprint("\n ")
				fprint(l)
			}

			fprint("\n )")
		} else {
			fprint("(")
			fprint(body)
			fprint(")")
		}
	} else if err := formatExpression(out, targetWidth, expression); err != nil {
		return err
	}

	if minCount == 1 && maxCount == 1 {
		// no quantifier, but a failed write of the grouping must be reported
		return err
	}

	switch {
	case minCount == 0 && maxCount == 1:
		fprint("?")
	case minCount == 0 && maxCount < 0:
		fprint("*")
	case minCount == 1 && maxCount < 0:
		fprint("+")
	case minCount == maxCount:
		fprint("{")
		fprint(minCount)
		fprint("}")
	default:
		fprint("{")
		if minCount > 0 {
			fprint(minCount)
		}

		fprint(",")
		if maxCount >= 0 {
			fprint(maxCount)
		}

		fprint("}")
	}

	return err
}
// formatSequenceItemNodes writes the items of a sequence to out. The items are
// separated by a single space when targetWidth is negative, and by a newline
// followed by a space otherwise. Comment items are written with formatComment,
// all the others with formatSequenceItemNode.
func formatSequenceItemNodes(out io.Writer, targetWidth int, n []*Node) error {
	separator := "\n "
	if targetWidth < 0 {
		separator = " "
	}

	for idx, item := range n {
		if idx != 0 {
			if _, err := fmt.Fprint(out, separator); err != nil {
				return err
			}
		}

		var err error
		if item.Name == "comment" {
			err = formatComment(out, item)
		} else {
			err = formatSequenceItemNode(out, targetWidth, item)
		}

		if err != nil {
			return err
		}
	}

	return nil
}
// formatSequence writes the sequence items n to out. The items are first
// formatted on a single line; when targetWidth is not negative and the single
// line is longer, in bytes, than targetWidth, they are formatted again in the
// multiline layout.
func formatSequence(out io.Writer, targetWidth int, n []*Node) error {
	var formatted bytes.Buffer
	if err := formatSequenceItemNodes(&formatted, -1, n); err != nil {
		return err
	}

	fits := targetWidth < 0 || formatted.Len() <= targetWidth
	if !fits {
		// discard the single line attempt and break the items into lines
		formatted.Reset()
		if err := formatSequenceItemNodes(&formatted, targetWidth, n); err != nil {
			return err
		}
	}

	if _, err := io.Copy(out, &formatted); err != nil {
		return err
	}

	return nil
}
// formatChoiceOptionNodes writes the options of a choice to out. With a
// negative targetWidth, the options are separated by " | " and comments are
// preceded by a space. Otherwise every option starts on a new line with "| ",
// and every comment starts on a new line. No separator is written before the
// first node.
func formatChoiceOptionNodes(out io.Writer, targetWidth int, n []*Node) error {
	optionSep, commentSep := "\n| ", "\n"
	if targetWidth < 0 {
		optionSep, commentSep = " | ", " "
	}

	for idx, option := range n {
		isComment := option.Name == "comment"
		if idx > 0 {
			lead := optionSep
			if isComment {
				lead = commentSep
			}

			if _, err := fmt.Fprint(out, lead); err != nil {
				return err
			}
		}

		var err error
		if isComment {
			err = formatComment(out, option)
		} else {
			err = formatExpression(out, targetWidth, option)
		}

		if err != nil {
			return err
		}
	}

	return nil
}
// formatChoice writes the choice options n to out. The options are first
// formatted on a single line; when targetWidth is not negative and the single
// line is longer, in bytes, than targetWidth, they are formatted again in the
// multiline layout.
func formatChoice(out io.Writer, targetWidth int, n []*Node) error {
	var formatted bytes.Buffer
	if err := formatChoiceOptionNodes(&formatted, -1, n); err != nil {
		return err
	}

	fits := targetWidth < 0 || formatted.Len() <= targetWidth
	if !fits {
		// discard the single line attempt and put every option on its own line
		formatted.Reset()
		if err := formatChoiceOptionNodes(&formatted, targetWidth, n); err != nil {
			return err
		}
	}

	if _, err := io.Copy(out, &formatted); err != nil {
		return err
	}

	return nil
}
// formatExpression writes the expression node n to out, dispatching on the
// node name to the matching formatter. Sequences and choices receive
// targetWidth to decide between the single line and the multiline layout.
// Nodes with any other name produce no output and no error.
func formatExpression(out io.Writer, targetWidth int, n *Node) error {
	switch n.Name {
	case "comment":
		return formatComment(out, n)
	case "any-char":
		return formatAnyChar(out)
	case "char-class":
		return formatCharClass(out, n)
	case "char-sequence":
		return formatCharSequence(out, n)
	case "symbol":
		return formatSymbol(out, n)
	case "sequence":
		return formatSequence(out, targetWidth, n.Nodes)
	case "choice":
		return formatChoice(out, targetWidth, n.Nodes)
	default:
		return nil
	}
}
// formatDefinition writes a single definition in the form "name = expression;".
//
// The name is padded on the right with the leading part of pad, up to
// namesWidth bytes. The expression is the last child of n, and it is formatted
// with targetWidth reduced by namesWidth+3 to account for the name and " = ".
// Every continuation line of a multiline expression is prefixed with a space
// and the full pad. The definition is closed with a semicolon, without a
// trailing newline.
func formatDefinition(out io.Writer, targetWidth, namesWidth int, pad, name string, n *Node) error {
	padding := pad[:namesWidth-len(name)]
	if _, err := fmt.Fprintf(out, "%s%s = ", name, padding); err != nil {
		return err
	}

	var expression bytes.Buffer
	expressionWidth := decTargetWidth(targetWidth, namesWidth+3)
	if err := formatExpression(&expression, expressionWidth, n.Nodes[len(n.Nodes)-1]); err != nil {
		return err
	}

	// collect the pieces in output order, then write them one by one, stopping
	// at the first failure
	lines := strings.Split(expression.String(), "\n")
	pieces := make([]string, 0, 3*len(lines))
	pieces = append(pieces, lines[0])
	for _, line := range lines[1:] {
		pieces = append(pieces, "\n ", pad, line)
	}

	pieces = append(pieces, ";")
	for _, piece := range pieces {
		if _, err := fmt.Fprint(out, piece); err != nil {
			return err
		}
	}

	return nil
}
// formatASTGroup writes one group of top-level items to out.
//
// An empty group produces no output. A group that starts with a standalone
// comment is written as that comment alone. A leading header comment is
// written first, on its own line. The remaining items are definitions, each
// written on its own line with the names padded to the width of the longest
// name in the group, and suffix comments, each appended to the current line
// after a single space.
//
// It returns the first error from writing to out or from formatting an item.
func formatASTGroup(out io.Writer, g formatGroup) error {
	if len(g.items) == 0 {
		return nil
	}

	if g.items[0].commentFormat == standaloneComment {
		return formatComment(out, g.items[0].node)
	}

	hasHeaderComment := g.items[0].commentFormat == headerComment
	if hasHeaderComment {
		if err := formatComment(out, g.items[0].node); err != nil {
			return err
		}

		g.items = g.items[1:]
	}

	names, namesWidth := formatItemNames(g)
	pad := strings.Repeat(" ", namesWidth)
	for i, item := range g.items {
		if item.commentFormat == suffixComment {
			if _, err := fmt.Fprint(out, " "); err != nil {
				return err
			}

			if err := formatComment(out, item.node); err != nil {
				return err
			}

			continue
		}

		// every definition starts on a new line, except the very first output
		// of the group
		if hasHeaderComment || i > 0 {
			if _, err := fmt.Fprintln(out); err != nil {
				return err
			}
		}

		if err := formatDefinition(
			out,
			initialTargetWidth,
			namesWidth,
			pad,
			names[i],
			item.node,
		); err != nil {
			return err
		}
	}

	return nil
}
// formatAST writes the top-level nodes of ast to out. The nodes are split into
// groups with groupASTByComments, each group is written with formatASTGroup,
// and consecutive groups are separated by two newline characters.
//
// Design notes on comment handling, kept from the original implementation.
// They describe the intended rules; not all of them are necessarily implemented
// by the code in this file.
//
// General:
//   - drop whitespace comments
//   - use line comments by default
//   - comment types: standalone, header, suffix and inline comment
//
// Standalone comment:
//   - preceded by a definition or at least two empty lines, and followed by at
//     least two empty lines
//   - separate it by two empty lines above and below
//
// Header comment:
//   - separated from the subsequent definition by zero or one empty lines
//   - separate it by two empty lines above and one empty line below
//
// Suffix comment:
//   - starts on the same line as the definition it belongs to
//   - append it to the definition
//   - if it consists of multiple lines, append a new line below
//
// Inline comment:
//   - it is inside a definition
//   - if it is before the eq sign, discard the name padding and use a block
//     comment
//   - if it is in an expression, falls on its own line, and fits on the
//     previous line, put it there
//   - if it is in an expression and falls on its own line, use a line comment
//   - if it is in an expression and is followed by a non-comment on the same
//     line, use a block comment
//   - if it consists of multiple lines, append a new line below the definition
func formatAST(out io.Writer, ast *Node) error {
	groups := groupASTByComments(ast)
	for i := range groups {
		if i != 0 {
			if _, err := fmt.Fprint(out, "\n\n"); err != nil {
				return err
			}
		}

		if err := formatASTGroup(out, groups[i]); err != nil {
			return err
		}
	}

	return nil
}
// formatDefinitions writes the user defined definitions of the syntax s to out,
// one per definition in the form "name = expression;" followed by a newline,
// in the order of the registry.
//
// The name of a definition is its node name, followed by a colon and its
// commit type flags when any remain after clearing userDefined and, for char
// sequences, NoWhitespace. The names are padded to the width of the longest
// one, measured in runes. The expressions are formatted in pretty mode with the
// target width reduced by the names width plus three, and their continuation
// lines are indented by the names width plus one space.
//
// It returns the first error from writing to out.
func formatDefinitions(out io.Writer, s *Syntax) error {
	// namedDefinition keeps a definition together with its formatted name and
	// the width of that name in runes.
	type namedDefinition struct {
		name  string
		width int
		def   definition
	}

	var o formatOptions
	o.mode = formatPretty
	o.targetWidth = initialTargetWidth

	var (
		namesWidth int
		ordered    []namedDefinition
	)

	for _, def := range s.registry.definitions {
		ct := def.commitType()
		if ct&userDefined == 0 {
			continue
		}

		name := def.nodeName()
		ct &^= userDefined
		if sq, ok := def.(*sequenceDefinition); ok && sq.isCharSequence(s.registry) {
			ct &^= NoWhitespace
		}

		if ct != None {
			name = fmt.Sprintf("%s:%v", name, ct)
		}

		// The width is counted in runes, and the same number must be used
		// when slicing the padding below: the byte length of a non-ASCII name
		// exceeds its rune count and would produce a negative slice bound.
		width := len([]rune(name))
		if width > namesWidth {
			namesWidth = width
		}

		ordered = append(ordered, namedDefinition{name: name, width: width, def: def})
	}

	o.targetWidth = decTargetWidth(o.targetWidth, namesWidth+3)
	pad := strings.Repeat(" ", namesWidth)
	for _, d := range ordered {
		lines := strings.Split(d.def.format(s.registry, o), "\n")
		if _, err := fmt.Fprintf(
			out,
			"%s%s = %s",
			d.name,
			pad[:namesWidth-d.width],
			lines[0],
		); err != nil {
			return err
		}

		for _, l := range lines[1:] {
			if _, err := fmt.Fprintf(out, "\n%s %s", pad, l); err != nil {
				return err
			}
		}

		if _, err := fmt.Fprint(out, ";\n"); err != nil {
			return err
		}
	}

	return nil
}