# AST 解析
抽象语法树(AST)可以帮助我们更好地处理 python 的源码,是静态分析中常用的工具。下面让我们一起来看看如何利用 AST 来处理 python 源码。
# 获取源码 AST
首先,我们来处理一个简单的 python 语句 print('hello!')
1 2 3 4 5 6 import astfrom astpretty import pprintsource_code = "print('hello!')" abstract_tree = ast.parse(source_code) pprint(abstract_tree)
print('hello!')
的 AST 如下所示
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 Module( body=[ Expr( lineno=1, col_offset=0, value=Call( lineno=1, col_offset=0, func=Name(lineno=1, col_offset=0, id='print' , ctx=Load()), args=[Str(lineno=1, col_offset=6, s='hello!' )], keywords=[], ), ), ], )
如上所示,就是一个简单的 python 代码的 AST 表示。
更多关于 AST 语法相关的内容可以参考【官方文档】 。
# 遍历 AST 节点
得到 AST 表示显然不是我们的最终目的,我们想要能够借助 AST 来获取更多信息。
举个例子,我们想要获取 python 源码中所有函数定义的名称。下面我们就开始探索一下,如何利用 AST 来达成我们的目的。
sample.py
1 2 def foo (): print ('func foo' )
利用上述方法我们可以得到 sample.py
的 AST 表示,如下。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 Module( body=[ FunctionDef( lineno=1, col_offset=0, name='foo' , args=arguments(args=[], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[]), body=[ Expr( lineno=2, col_offset=4, value=Call( lineno=2, col_offset=4, func=Name(lineno=2, col_offset=4, id='print' , ctx=Load()), args=[Str(lineno=2, col_offset=10, s='func foo' )], keywords=[], ), ), ], decorator_list=[], returns=None, ), ], )
观察可知,我们只需要拿到 name='foo'
这个信息就可以,那么我们怎么拿到这个值呢?
AST 是一个个节点构成的树状结构,通过遍历 AST 上的每个节点,我们可以获得其相对应的属性信息。例如 Module
就是一个节点,其包含一个 body
的属性。
先上代码 visit_ast.py
1 2 3 4 5 6 7 8 9 10 11 12 13 14 import astcontent = open ("sample.py" ).read() script_ast = ast.parse(content) class MyVisit (ast.NodeVisitor ): def visit_FunctionDef (self,node ): print (node.name) return node m = MyVisit() m.visit(script_ast)
输出结果
我们继承了 ast.NodeVisitor
类,重载了访问 FunctionDef
节点的方法,将该节点对应的 node.name
进行了输出。如果你想存储函数的名称,可以在__init__函数中定一个数组,来存储这些 function name。
但上述代码存在着一个问题,当访问完 FunctionDef
节点后,我们选择了直接 return,这样会导致 FunctionDef
中的子节点无法被访问。如果存在函数嵌套定义,我们就无法拿到所有的函数名称。我们修改 sample.py
的内容,如下所示
1 2 3 4 def foo (): print ('hello!' ) def bar (): print ('world!' )
执行 visit_ast.py
,得到的输出结果
为了解决这个问题,我们需要看一下 ast 中关于 NodeVisitor 类的定义
# ast.NodeVisitor
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 class NodeVisitor (object ): """ A node visitor base class that walks the abstract syntax tree and calls a visitor function for every node found. ... """ def visit (self, node ): """Visit a node.""" method = 'visit_' + node.__class__.__name__ visitor = getattr (self, method, self.generic_visit) return visitor(node) def generic_visit (self, node ): """Called if no explicit visitor function exists for a node.""" for field, value in iter_fields(node): if isinstance (value, list ): for item in value: if isinstance (item, AST): self.visit(item) elif isinstance (value, AST): self.visit(value) getattr (object , name[, default])返回一个对象的属性name,如果没有则返回默认值 getattr (self, method, self.generic_visit)通过继承 NodeVisitor,并定义method方法,就能实现多种类型节点的遍历
解析一下这个代码
visit(self, node)
会先通过获取节点对应的类名,来获取对应的访问方法。最开始我们传递的 script_ast
就是一个 Module
,那么 method
即为 visit_Module
。
visitor = getattr(self, method, self.generic_visit)
这里涉及一个 Python 的函数 getattr
,其原始的函数定义为 getattr(object, name[, default])
,这个函数的作用是返回对象 object 的一个名为 name 的属性,如果不存在则返回默认值 default。
在这个语境下的作用就是返回 NodeVisitor.{method}
这个属性,即返回 NodeVisitor.visit_Module
。如果这个属性没有定义,那么返回 NodeVisitor.generic_visit
。这样就可以理解为什么我们定义了一个 visit_FunctionDef
函数,就可以处理该节点的相关信息。
generic_visit(self, node)
这个是默认的节点遍历函数,从函数定义可知其会利用 visit
函数来遍历当前节点中所有键值对的值。
回顾我们的需求,我们想在访问 FunctionDef
节点后,还能够处理其子节点中的 FunctionDef
,结合这些代码,我们可以修正我们的遍历程序,如下所示
1 2 3 4 5 6 7 8 9 10 11 12 13 14 import astcontent = open ("sample.py" ).read() script_ast = ast.parse(content) class MyVisit (ast.NodeVisitor ): def visit_FunctionDef (self, node ): print (node.name) self.generic_visit(node) m = MyVisit() m.visit(script_ast)
运行结果
我们在访问第一个 FunctionDef
节点之后,使用 generic_visit
访问其子节点,结合上面函数定义,我们就可以访问到嵌套定义的第二个函数 bar
。
# AST 修改
既然不同的源码对应不同的 AST,那么如果我们直接在 AST 层面修改,其对应的源码是否也会发生变化呢?答案是 yes,接下来我们就通过一个简单的例子学习一下如何修改 AST,以及如果通过 AST 生成 Python 源码。
sample.py
1 2 def foo (): print ('func foo' )
AST 表示
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 Module( body=[ FunctionDef( lineno=1, col_offset=0, name='foo', args=arguments(args=[], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[]), body=[ Expr( lineno=2, col_offset=4, value=Call( lineno=2, col_offset=4, func=Name(lineno=2, col_offset=4, id='print', ctx=Load()), args=[Str(lineno=2, col_offset=10, s='func foo')], keywords=[], ), ), ], decorator_list=[], returns=None, ), ], )
假设我们想把上面 sample.py
中输出的内容修改成 hello world!
, ast
这个库也提供了相应的方法供我们参考。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 import astimport astunparsecontent = open ("sample.py" ).read() script_ast = ast.parse(content) class MyTransformer (ast.NodeTransformer ): def visit_Str (self, node ): node.s = 'hello world!' return node m = MyTransformer() modified = m.visit(script_ast) print (astunparse.Unparser(modified))
运行结果
1 2 def foo (): print ('hello world!' )
同理,我们看一下 ast
中关于 NodeTransformer
的定义
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 class NodeTransformer (NodeVisitor ): """ A :class:`NodeVisitor` subclass that walks the abstract syntax tree and allows modification of nodes. ... """ def generic_visit (self, node ): for field, old_value in iter_fields(node): if isinstance (old_value, list ): new_values = [] for value in old_value: if isinstance (value, AST): value = self.visit(value) if value is None : continue elif not isinstance (value, AST): new_values.extend(value) continue new_values.append(value) old_value[:] = new_values elif isinstance (old_value, AST): new_node = self.visit(old_value) if new_node is None : delattr (node, field) else : setattr (node, field, new_node) return node
NodeTransformer
继承了 NodeVisitor
,所以定义遍历节点的方法是一样的。我们想修改 Str
节点,所以就重载了 visit_Str
这个函数,并对其进行修改。
generic_visit(self, node)
我们直接观察第二个分支语句
1 2 3 4 5 6 elif isinstance (old_value, AST): new_node = self.visit(old_value) if new_node is None : delattr (node, field) else : setattr (node, field, new_node)
当遍历 old_value
时,会得到一个返回值 new_node
,然后如果返回值为 None
,则代表 old_value
被删除了,所以对当前遍历的节点 node
进行了删除属性操作 delattr(node, field)
。若不为 None
,则用新的值替换原有的 old_value
,即 setattr(node, field, new_node)
所以当我们遍历修改完 node.s
之后,我们需要将修改后的节点进行返回。
当这个节点 node
的处理都结束后, generic_visit
会返回修改后的节点。也就完成了 AST 的重构
通过第三方库函数 astunparse.Unparser
,我们可以将修改后的 AST 还原为源码。