test: address litgen review feedback

This commit is contained in:
ZhouGuangyuan
2026-04-17 23:16:29 +08:00
parent 8cb9e72c1f
commit cb0feeb190
6 changed files with 166 additions and 149 deletions
+3 -37
View File
@@ -17,13 +17,13 @@
package main
import (
"bufio"
"fmt"
"os"
"path/filepath"
"strings"
"github.com/goplus/llgo/cl/cltest"
"github.com/goplus/llgo/internal/littest"
"github.com/goplus/llgo/internal/llgen"
"github.com/goplus/mod"
)
@@ -107,43 +107,9 @@ func check(err error) {
}
func dirHasLITTESTSource(dir string) (bool, error) {
entries, err := os.ReadDir(dir)
_, ok, err := littest.FindMarkedSourceFile(dir)
if err != nil {
return false, err
}
for _, entry := range entries {
if entry.IsDir() {
continue
}
name := entry.Name()
if filepath.Ext(name) != ".go" || strings.HasSuffix(name, "_test.go") {
continue
}
ok, err := hasLITTESTMarker(filepath.Join(dir, name))
if err != nil {
return false, err
}
if ok {
return true, nil
}
}
return false, nil
}
func hasLITTESTMarker(path string) (bool, error) {
f, err := os.Open(path)
if err != nil {
return false, err
}
defer f.Close()
scanner := bufio.NewScanner(f)
if !scanner.Scan() {
return false, scanner.Err()
}
line := strings.TrimSpace(scanner.Text())
if !strings.HasPrefix(line, "//") {
return false, nil
}
return strings.TrimSpace(strings.TrimPrefix(line, "//")) == "LITTEST", nil
return ok, nil
}
+8 -5
View File
@@ -21,6 +21,8 @@ import (
"fmt"
"os"
"path/filepath"
"github.com/goplus/llgo/internal/littest"
)
func main() {
@@ -34,7 +36,7 @@ func main() {
os.Exit(2)
}
for _, arg := range flag.Args() {
check(processPath(arg))
fatal(processPath(arg))
}
}
@@ -53,7 +55,7 @@ func processPath(path string) error {
if filepath.Ext(abs) != ".go" {
return fmt.Errorf("%s: expected .go file or directory", abs)
}
ok, err := hasLITTESTMarker(abs)
ok, err := littest.HasMarker(abs)
if err != nil {
return err
}
@@ -80,7 +82,7 @@ func processTree(root string) error {
if path != root && len(d.Name()) > 0 && d.Name()[0] == '_' {
return filepath.SkipDir
}
marked, found, err := findMarkedSourceFile(path)
marked, found, err := littest.FindMarkedSourceFile(path)
if err != nil {
return err
}
@@ -109,8 +111,9 @@ func processTree(root string) error {
return nil
}
func check(err error) {
func fatal(err error) {
if err != nil {
panic(err)
fmt.Fprintf(os.Stderr, "litgen: %v\n", err)
os.Exit(1)
}
}
+58 -73
View File
@@ -17,7 +17,6 @@
package main
import (
"bufio"
"fmt"
"go/ast"
"go/format"
@@ -28,15 +27,15 @@ import (
"path/filepath"
"reflect"
"regexp"
"runtime/debug"
"sort"
"strings"
"github.com/goplus/llgo/internal/llgen"
"github.com/goplus/mod"
"golang.org/x/mod/modfile"
)
const littestMarker = "LITTEST"
type resolvedTarget struct {
sourceFile string
genTarget string
@@ -90,7 +89,7 @@ func generateFile(target resolvedTarget) error {
if err != nil {
return fmt.Errorf("%s: gofmt failed: %w", target.sourceFile, err)
}
return os.WriteFile(target.sourceFile, formatted, 0644)
return writeFileAtomically(target.sourceFile, formatted, 0644)
}
func resolveTarget(sourceFile, genTarget string) (resolvedTarget, error) {
@@ -119,33 +118,30 @@ func resolveTarget(sourceFile, genTarget string) (resolvedTarget, error) {
func genIR(target string) (ret string, err error) {
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("llgen failed for %s: %v", target, r)
switch v := r.(type) {
case error:
err = fmt.Errorf("llgen failed for %s: %w", target, v)
case string:
err = fmt.Errorf("llgen failed for %s: %s", target, v)
default:
_, _ = os.Stderr.Write(debug.Stack())
panic(r)
}
}
}()
return llgen.GenFrom(target), nil
}
func readModulePath(goMod string) (string, error) {
f, err := os.Open(goMod)
data, err := os.ReadFile(goMod)
if err != nil {
return "", err
}
defer f.Close()
scanner := bufio.NewScanner(f)
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if line == "" || strings.HasPrefix(line, "//") {
continue
}
if strings.HasPrefix(line, "module ") {
return strings.TrimSpace(strings.TrimPrefix(line, "module ")), nil
}
modulePath := modfile.ModulePath(data)
if modulePath == "" {
return "", fmt.Errorf("%s: module directive not found", goMod)
}
if err := scanner.Err(); err != nil {
return "", err
}
return "", fmt.Errorf("%s: module directive not found", goMod)
return modulePath, nil
}
func packagePath(modulePath, root, pkgDir string) (string, error) {
@@ -199,6 +195,7 @@ func rewriteSource(src, srcPath, pkgPath, modulePath, ir string) (string, error)
func collectAnchors(src string, fset *token.FileSet, file *ast.File) (map[string]int, int) {
anchors := make(map[string]int)
counts := make(map[string]int)
topPos := topInsertPos(src, fset, file)
if initPos, ok := syntheticInitPos(src, fset, file); ok {
anchors["init"] = initPos
@@ -208,12 +205,12 @@ func collectAnchors(src string, fset *token.FileSet, file *ast.File) (map[string
case *ast.FuncDecl:
name := inPkgFuncName(d)
anchors[name] = declInsertPos(src, fset, d.Pos(), d.Doc)
collectFuncLitAnchors(src, fset, d.Body, name, anchors)
collectFuncLitAnchors(src, fset, d.Body, name, anchors, counts)
case *ast.GenDecl:
if d.Tok == token.IMPORT {
continue
}
collectFuncLitAnchors(src, fset, d, "init", anchors)
collectFuncLitAnchors(src, fset, d, "init", anchors, counts)
}
}
return anchors, topPos
@@ -252,11 +249,10 @@ func declInsertPos(src string, fset *token.FileSet, pos token.Pos, doc *ast.Comm
return lineStart(src, offsetOf(fset, pos))
}
func collectFuncLitAnchors(src string, fset *token.FileSet, node ast.Node, parent string, anchors map[string]int) {
func collectFuncLitAnchors(src string, fset *token.FileSet, node ast.Node, parent string, anchors map[string]int, counts map[string]int) {
if isNilNode(node) {
return
}
counts := make(map[string]int)
var walk func(ast.Node, string)
walk = func(root ast.Node, current string) {
if isNilNode(root) {
@@ -449,7 +445,26 @@ func generalizeModulePath(line, modulePath string) string {
if modulePath == "" {
return line
}
return strings.ReplaceAll(line, modulePath, "{{.*}}")
var b strings.Builder
start := 0
inQuote := false
for i := 0; i < len(line); i++ {
if line[i] != '"' {
continue
}
if !inQuote {
b.WriteString(line[start : i+1])
start = i + 1
inQuote = true
continue
}
b.WriteString(strings.ReplaceAll(line[start:i], modulePath, "{{.*}}"))
b.WriteByte('"')
start = i + 1
inQuote = false
}
b.WriteString(line[start:])
return b.String()
}
func shouldCheckGlobal(symbol string) bool {
@@ -563,57 +578,27 @@ func offsetOf(fset *token.FileSet, pos token.Pos) int {
return fset.PositionFor(pos, false).Offset
}
func findMarkedSourceFile(dir string) (string, bool, error) {
entries, err := os.ReadDir(dir)
func writeFileAtomically(path string, data []byte, perm os.FileMode) (err error) {
dir := filepath.Dir(path)
tmp, err := os.CreateTemp(dir, ".litgen-*")
if err != nil {
return "", false, err
return err
}
var marked string
for _, entry := range entries {
if entry.IsDir() {
continue
}
name := entry.Name()
if !isSourceSpecFile(name) {
continue
}
path := filepath.Join(dir, name)
ok, err := hasLITTESTMarker(path)
tmpPath := tmp.Name()
defer func() {
_ = tmp.Close()
if err != nil {
return "", false, err
_ = os.Remove(tmpPath)
}
if !ok {
continue
}
if marked != "" {
return "", false, fmt.Errorf("%s: multiple // LITTEST sources found: %s, %s", dir, filepath.Base(marked), name)
}
marked = path
}()
if err = tmp.Chmod(perm); err != nil {
return err
}
if marked == "" {
return "", false, nil
if _, err = tmp.Write(data); err != nil {
return err
}
return marked, true, nil
}
func hasLITTESTMarker(path string) (bool, error) {
f, err := os.Open(path)
if err != nil {
return false, err
}
defer f.Close()
scanner := bufio.NewScanner(f)
if !scanner.Scan() {
return false, scanner.Err()
}
line := strings.TrimSpace(scanner.Text())
if !strings.HasPrefix(line, "//") {
return false, nil
}
return strings.TrimSpace(strings.TrimPrefix(line, "//")) == littestMarker, nil
}
func isSourceSpecFile(name string) bool {
return filepath.Ext(name) == ".go" && !strings.HasSuffix(name, "_test.go")
if err = tmp.Close(); err != nil {
return err
}
return os.Rename(tmpPath, path)
}
+49
View File
@@ -183,3 +183,52 @@ _llgo_0:
t.Fatalf("global checks should be placed before first declaration:\n%s", got)
}
}
func TestRewriteSource_SharesInitClosureCountsAcrossDecls(t *testing.T) {
const src = `// LITTEST
package main
var a = func() int { return 1 }()
var b = func() int { return 2 }()
`
const ir = `define void @"example.com/p.init"() {
_llgo_0:
%0 = call i64 @"example.com/p.init$1"()
%1 = call i64 @"example.com/p.init$2"()
ret void
}
define i64 @"example.com/p.init$1"() {
_llgo_0:
ret i64 1
}
define i64 @"example.com/p.init$2"() {
_llgo_0:
ret i64 2
}
`
got, err := rewriteSource(src, "in.go", "example.com/p", "example.com", ir)
if err != nil {
t.Fatal(err)
}
firstCheck := `// CHECK-LABEL: define i64 @"{{.*}}/p.init$1"() {`
secondCheck := `// CHECK-LABEL: define i64 @"{{.*}}/p.init$2"() {`
firstVar := "var a = func() int { return 1 }()"
secondVar := "var b = func() int { return 2 }()"
if strings.Index(got, firstCheck) > strings.Index(got, firstVar) {
t.Fatalf("first init closure should be anchored before first var decl:\n%s", got)
}
if strings.Index(got, secondCheck) > strings.Index(got, secondVar) {
t.Fatalf("second init closure should be anchored before second var decl:\n%s", got)
}
}
func TestGeneralizeModulePath_ReplacesOnlyQuotedSegments(t *testing.T) {
line := ` %0 = getelementptr inbounds %"go/example.Type", ptr @"go/example.fn"`
got := generalizeModulePath(line, "go")
want := ` %0 = getelementptr inbounds %"{{.*}}/example.Type", ptr @"{{.*}}/example.fn"`
if got != want {
t.Fatalf("generalizeModulePath = %q, want %q", got, want)
}
}
+42 -28
View File
@@ -42,7 +42,7 @@ type Spec struct {
Mode Mode
}
const marker = "LITTEST"
const Marker = "LITTEST"
func LoadSpec(pkgDir string) (Spec, error) {
if spec, ok, err := loadSourceSpec(pkgDir); err != nil {
@@ -79,31 +79,12 @@ func Check(spec Spec, actual string) error {
}
func loadSourceSpec(pkgDir string) (Spec, bool, error) {
entries, err := os.ReadDir(pkgDir)
marked, ok, err := FindMarkedSourceFile(pkgDir)
if err != nil {
return Spec{}, false, err
}
var marked string
for _, entry := range entries {
if entry.IsDir() {
continue
}
name := entry.Name()
if !isSourceSpecFile(name) {
continue
}
path := filepath.Join(pkgDir, name)
ok, err := hasMarker(path)
if err != nil {
return Spec{}, false, err
}
if !ok {
continue
}
if marked != "" {
return Spec{}, false, fmt.Errorf("%s: multiple source lit specs found: %s, %s", pkgDir, filepath.Base(marked), filepath.Base(path))
}
marked = path
if !ok {
return Spec{}, false, nil
}
if marked == "" {
return Spec{}, false, nil
@@ -113,12 +94,12 @@ func loadSourceSpec(pkgDir string) (Spec, bool, error) {
return Spec{}, false, err
}
text := string(data)
ok, err := filecheck.HasDirectives(text)
ok, err = filecheck.HasDirectives(text)
if err != nil {
return Spec{}, false, err
}
if !ok {
return Spec{}, false, fmt.Errorf("%s: %s is marked %s but has no FileCheck directives", pkgDir, filepath.Base(marked), marker)
return Spec{}, false, fmt.Errorf("%s: %s is marked %s but has no FileCheck directives", pkgDir, filepath.Base(marked), Marker)
}
return Spec{
Path: marked,
@@ -127,7 +108,40 @@ func loadSourceSpec(pkgDir string) (Spec, bool, error) {
}, true, nil
}
func hasMarker(path string) (bool, error) {
func FindMarkedSourceFile(dir string) (string, bool, error) {
entries, err := os.ReadDir(dir)
if err != nil {
return "", false, err
}
var marked string
for _, entry := range entries {
if entry.IsDir() {
continue
}
name := entry.Name()
if !IsSourceSpecFile(name) {
continue
}
path := filepath.Join(dir, name)
ok, err := HasMarker(path)
if err != nil {
return "", false, err
}
if !ok {
continue
}
if marked != "" {
return "", false, fmt.Errorf("%s: multiple source lit specs found: %s, %s", dir, filepath.Base(marked), filepath.Base(path))
}
marked = path
}
if marked == "" {
return "", false, nil
}
return marked, true, nil
}
func HasMarker(path string) (bool, error) {
f, err := os.Open(path)
if err != nil {
return false, err
@@ -142,9 +156,9 @@ func hasMarker(path string) (bool, error) {
if !strings.HasPrefix(line, "//") {
return false, nil
}
return strings.TrimSpace(strings.TrimPrefix(line, "//")) == marker, nil
return strings.TrimSpace(strings.TrimPrefix(line, "//")) == Marker, nil
}
func isSourceSpecFile(name string) bool {
func IsSourceSpecFile(name string) bool {
return filepath.Ext(name) == ".go" && !strings.HasSuffix(name, "_test.go")
}
+6 -6
View File
@@ -177,9 +177,9 @@ func TestLoadSpecSupportsSkipOutLL(t *testing.T) {
func TestHasMarker(t *testing.T) {
dir := t.TempDir()
ok, err := hasMarker(filepath.Join(dir, "missing.go"))
ok, err := HasMarker(filepath.Join(dir, "missing.go"))
if err == nil || ok {
t.Fatalf("hasMarker(missing) = (%v, %v)", ok, err)
t.Fatalf("HasMarker(missing) = (%v, %v)", ok, err)
}
empty := filepath.Join(dir, "empty.go")
@@ -187,9 +187,9 @@ func TestHasMarker(t *testing.T) {
if err != nil {
t.Fatal(err)
}
ok, err = hasMarker(empty)
ok, err = HasMarker(empty)
if err != nil || ok {
t.Fatalf("hasMarker(empty) = (%v, %v)", ok, err)
t.Fatalf("HasMarker(empty) = (%v, %v)", ok, err)
}
plain := filepath.Join(dir, "plain.go")
@@ -197,9 +197,9 @@ func TestHasMarker(t *testing.T) {
if err != nil {
t.Fatal(err)
}
ok, err = hasMarker(plain)
ok, err = HasMarker(plain)
if err != nil || ok {
t.Fatalf("hasMarker(plain) = (%v, %v)", ok, err)
t.Fatalf("HasMarker(plain) = (%v, %v)", ok, err)
}
}