Building an AST with ANTLR4


Challenge: build an AST with ANTLR4

1. Define the grammar rules

Write the Calculator.g4 file as follows:

grammar Calculator;

prog    : stat+ ;

stat    : expr NEWLINE              # printExpr
        ;

expr    : expr '*' expr             # mul
        | expr '/' expr             # div
        | expr '+' expr             # add
        | expr '-' expr             # sub
        | '(' expr ')'              # bracket
        | INT                       # int
        ;

NEWLINE :'\r'? '\n' ;

INT     : [0-9]+ ;

MUL     : '*' ;
DIV     : '/' ;
ADD     : '+' ;
SUB     : '-' ;

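// Skip spaces and tabs only. If WS also matched '\r' and '\n', a run such as " \n"
// or "\n\n" would be the longer match and the NEWLINE that ends a stat would be lost.
WS      : [ \t]+ -> skip ;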
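ANTLR4 resolves the ambiguity in a left-recursive rule such as expr by giving alternatives listed earlier a higher precedence, and binary operators are left-associative by default. The following is a minimal pure-Python sketch of that ordering using precedence climbing. It is not ANTLR's generated parser, and PRECEDENCE, tokenize, parse and parse_primary are illustrative names that do not appear in the generated code.

```python
import re

# Binding strength mirrors the order of the alternatives in `expr`:
# mul is listed first, then div, add, sub, so '*' binds tightest.
PRECEDENCE = {"*": 4, "/": 3, "+": 2, "-": 1}

TOKEN = re.compile(r"\s*(\d+|[-+*/()])")


def tokenize(text):
    """Split one expression into integer, operator and parenthesis tokens."""
    text = text.strip()
    tokens, pos = [], 0
    while pos < len(text):
        match = TOKEN.match(text, pos)
        if match is None:
            raise SyntaxError(f"unexpected character at position {pos}")
        tokens.append(match.group(1))
        pos = match.end()
    return tokens


def parse_primary(tokens):
    """Parse an integer or a parenthesised expression."""
    tok = tokens.pop(0)
    if tok == "(":
        inner = parse(tokens)
        tokens.pop(0)  # consume ')'
        return inner
    return int(tok)


def parse(tokens, min_prec=1):
    """Precedence climbing; returns an int or a nested (op, left, right) tuple."""
    left = parse_primary(tokens)
    while tokens and tokens[0] in PRECEDENCE and PRECEDENCE[tokens[0]] >= min_prec:
        op = tokens.pop(0)
        # Left associativity: the right operand must bind strictly tighter.
        right = parse(tokens, PRECEDENCE[op] + 1)
        left = (op, left, right)
    return left


tree_a = parse(tokenize("1 + 2 * 3"))
tree_b = parse(tokenize("1 - 2 + 3"))
```

Under this ordering 1 + 2 * 3 groups as 1 + (2 * 3), but 1 - 2 + 3 groups as 1 - (2 + 3), because add is listed before sub. The usual way to give two operators the same precedence is to put them in one alternative, for example expr op=('+'|'-') expr, which would also require merging the corresponding visitor methods.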

2. Define the actions for visiting the AST

Run

antlr4 -no-listener -visitor -Dlanguage=Python3 Calculator.g4

This generates the lexer, parser, and visitor automatically, all written in Python.


Create a myVisitor class that inherits from CalculatorVisitor and defines the actions for visiting the AST:

from antlr4 import *
from CalculatorLexer import CalculatorLexer
from CalculatorParser import CalculatorParser
from CalculatorVisitor import CalculatorVisitor


class myVisitor(CalculatorVisitor):

    def visitProg(self, ctx: CalculatorParser.ProgContext):
        # Evaluate every statement and collect the results in order.
        return [self.visit(stat) for stat in ctx.stat()]

    def visitPrintExpr(self, ctx: CalculatorParser.PrintExprContext):
        # Evaluate the expression, print it, and return the value.
        val = self.visit(ctx.expr())
        print(val)
        return val

    def visitDiv(self, ctx: CalculatorParser.DivContext):
        # True division: the result is a float even for integer operands.
        return self.visit(ctx.expr(0)) / self.visit(ctx.expr(1))

    def visitAdd(self, ctx: CalculatorParser.AddContext):
        return self.visit(ctx.expr(0)) + self.visit(ctx.expr(1))

    def visitSub(self, ctx: CalculatorParser.SubContext):
        return self.visit(ctx.expr(0)) - self.visit(ctx.expr(1))

    def visitMul(self, ctx: CalculatorParser.MulContext):
        return self.visit(ctx.expr(0)) * self.visit(ctx.expr(1))

    def visitBracket(self, ctx: CalculatorParser.BracketContext):
        return self.visit(ctx.expr())

    def visitInt(self, ctx: CalculatorParser.IntContext):
        return int(ctx.getText())
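Each labelled alternative in the grammar (# mul, # add, # int and so on) gets its own visit method, and calling visit on a node ends up in the method that matches the node's label. The sketch below imitates that dispatch in plain Python without the ANTLR runtime. Node, TreeVisitor, EvalVisitor and num are hypothetical stand-ins, and the getattr lookup is a simplification: the real runtime dispatches through the accept method of each generated context class.

```python
class Node:
    """Stand-in for a parse-tree context: a rule label, child nodes and token text."""

    def __init__(self, label, children=(), text=""):
        self.label = label
        self.children = list(children)
        self.text = text


class TreeVisitor:
    def visit(self, node):
        # Look up visit<Label>, e.g. visitAdd for a node labelled "Add".
        return getattr(self, "visit" + node.label)(node)


class EvalVisitor(TreeVisitor):
    def visitInt(self, node):
        return int(node.text)

    def visitBracket(self, node):
        return self.visit(node.children[0])

    def visitAdd(self, node):
        return self.visit(node.children[0]) + self.visit(node.children[1])

    def visitMul(self, node):
        return self.visit(node.children[0]) * self.visit(node.children[1])


def num(value):
    """Build a leaf node holding an integer literal."""
    return Node("Int", text=str(value))


# Hand-built tree for (1 + 2) * 3
tree = Node("Mul", [Node("Bracket", [Node("Add", [num(1), num(2)])]), num(3)])
result = EvalVisitor().visit(tree)
```

Because every visit method returns a value, the result of the whole expression is simply the return value of the call on the root node.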

3. Test

Given a file expr.txt containing the arithmetic expressions:

(1 + 2) * 3
1 + 2 * 3
1 + 2 + 3
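The stat rule requires a NEWLINE after every expression, so the last line of the file needs a trailing newline as well. A small sketch of producing such text follows; build_expr_text is a hypothetical helper, not part of the original project.

```python
def build_expr_text(expressions):
    """Join expressions one per line, ending every line (including the last) with a newline."""
    return "".join(expr + "\n" for expr in expressions)


text = build_expr_text(["(1 + 2) * 3", "1 + 2 * 3", "1 + 2 + 3"])
```

Writing text to expr.txt, for example with pathlib.Path("expr.txt").write_text(text), gives a file in which each of the three expressions is terminated by a newline.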

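main.py defines the function antlr:

from antlr4 import FileStream, CommonTokenStream
from CalculatorLexer import CalculatorLexer
from CalculatorParser import CalculatorParser
# Assumption: the visitor class is saved as myVisitor.py.
# Drop this import if myVisitor is defined in main.py itself.
from myVisitor import myVisitor


def antlr():
    # Lex and parse expr.txt, then evaluate the tree with the visitor.
    input_stream = FileStream("expr.txt")
    lexer = CalculatorLexer(input_stream)
    stream = CommonTokenStream(lexer)
    parser = CalculatorParser(stream)
    tree = parser.prog()
    visitor = myVisitor()
    return visitor.visit(tree)


# antlr() must be defined before it is called.
array = antlr()
answer = [9, 7, 6]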

Add the following to the unit tests:

def test_antlr(self):
    # Compare the whole lists so that a missing or extra result also fails.
    self.assertEqual(array, answer)

Final result

[Figure: screenshot of the test result]

Author: Paranoid
Reprint policy: Unless otherwise stated, all articles in this blog are licensed under CC BY 4.0. If you reproduce this article, please credit the source: Paranoid.