# AST 解析

抽象语法树(AST)可以帮助我们更好地处理 python 的源码,是静态分析中常用的工具。下面让我们一起来看看如何利用 AST 来处理 python 源码。

# 获取源码 AST

首先,我们来处理一个简单的 python 语句 print('hello!')

1
2
3
4
5
6
import ast
from astpretty import pprint

source_code = "print('hello!')"
abstract_tree = ast.parse(source_code)
pprint(abstract_tree)

print('hello!') 的 AST 如下所示

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
Module(
body=[
Expr(
lineno=1,
col_offset=0,
value=Call(
lineno=1,
col_offset=0,
func=Name(lineno=1, col_offset=0, id='print', ctx=Load()),
args=[Str(lineno=1, col_offset=6, s='hello!')],
keywords=[],
),
),
],
)

如上所示,就是一个简单的 python 代码的 AST 表示。

更多关于 AST 语法相关的内容可以参考【官方文档】

# 遍历 AST 节点

得到 AST 表示显然不是我们的最终目的,我们想要能够借助 AST 来获取更多信息。

举个例子,我们想要获取 python 源码中所有函数定义的名称。下面我们就开始探索一下,如何利用 AST 来达成我们的目的。

sample.py

1
2
def foo():
print('func foo')

利用上述方法我们可以得到 sample.py 的 AST 表示,如下。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
Module(
body=[
FunctionDef(
lineno=1,
col_offset=0,
name='foo',
args=arguments(args=[], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[]),
body=[
Expr(
lineno=2,
col_offset=4,
value=Call(
lineno=2,
col_offset=4,
func=Name(lineno=2, col_offset=4, id='print', ctx=Load()),
args=[Str(lineno=2, col_offset=10, s='func foo')],
keywords=[],
),
),
],
decorator_list=[],
returns=None,
),
],
)

观察可知,我们只需要拿到 name='foo' 这个信息就可以,那么我们怎么拿到这个值呢?

AST 是一个个节点构成的树状结构,通过遍历 AST 上的每个节点,我们可以获得其相对应的属性信息。例如 Module 就是一个节点,其包含一个 body 的属性。

先上代码 visit_ast.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
import ast

# 读取源码并构建ast
content = open("sample.py").read()
script_ast = ast.parse(content)

# 定义一个访问类
class MyVisit(ast.NodeVisitor):
def visit_FunctionDef(self,node):
print(node.name)
return node

m = MyVisit()
m.visit(script_ast)

输出结果

1
foo

我们继承了 ast.NodeVisitor 类,重载了访问 FunctionDef 节点的方法,将该节点对应的 node.name 进行了输出。如果你想存储函数的名称,可以在__init__函数中定一个数组,来存储这些 function name。

但上述代码存在着一个问题,当访问完 FunctionDef 节点后,我们选择了直接 return,这样会导致 FunctionDef 中的子节点无法被访问。如果存在函数嵌套定义,我们就无法拿到所有的函数名称。我们修改 sample.py 的内容,如下所示

1
2
3
4
def foo():
print('hello!')
def bar():
print('world!')

执行 visit_ast.py ,得到的输出结果

1
foo

为了解决这个问题,我们需要看一下 ast 中关于 NodeVisitor 类的定义

# ast.NodeVisitor

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
class NodeVisitor(object):
"""
A node visitor base class that walks the abstract syntax tree and calls a
visitor function for every node found. ...
"""

def visit(self, node):
"""Visit a node."""
method = 'visit_' + node.__class__.__name__
visitor = getattr(self, method, self.generic_visit)
return visitor(node)

def generic_visit(self, node):
"""Called if no explicit visitor function exists for a node."""
for field, value in iter_fields(node):
if isinstance(value, list):
for item in value:
if isinstance(item, AST):
self.visit(item)
elif isinstance(value, AST):
self.visit(value)

# 代码解读
getattr(object, name[, default])
返回一个对象的属性name,如果没有则返回默认值

getattr(self, method, self.generic_visit)
通过继承 NodeVisitor,并定义method方法,就能实现多种类型节点的遍历

解析一下这个代码

  • visit(self, node)

    会先通过获取节点对应的类名,来获取对应的访问方法。最开始我们传递的 script_ast 就是一个 Module ,那么 method 即为 visit_Module

    visitor = getattr(self, method, self.generic_visit) 这里涉及一个 Python 的函数 getattr ,其原始的函数定义为 getattr(object, name[, default]) ,这个函数的作用是返回对象 object 的一个名为 name 的属性,如果不存在则返回默认值 default。

    在这个语境下的作用就是返回 NodeVisitor.{method} 这个属性,即返回 NodeVisitor.visit_Module 。如果这个属性没有定义,那么返回 NodeVisitor.generic_visit 。这样就可以理解为什么我们定义了一个 visit_FunctionDef 函数,就可以处理该节点的相关信息。

  • generic_visit(self, node)

    这个是默认的节点遍历函数,从函数定义可知其会利用 visit 函数来遍历当前节点中所有键值对的值。

回顾我们的需求,我们想在访问 FunctionDef 节点后,还能够处理其子节点中的 FunctionDef ,结合这些代码,我们可以修正我们的遍历程序,如下所示

1
2
3
4
5
6
7
8
9
10
11
12
13
14
import ast

# 读取源码并构建ast
content = open("sample.py").read()
script_ast = ast.parse(content)

# 定义一个访问类
class MyVisit(ast.NodeVisitor):
def visit_FunctionDef(self, node):
print(node.name)
self.generic_visit(node)

m = MyVisit()
m.visit(script_ast)

运行结果

1
2
foo
bar

我们在访问第一个 FunctionDef 节点之后,使用 generic_visit 访问其子节点,结合上面函数定义,我们就可以访问到嵌套定义的第二个函数 bar

# AST 修改

既然不同的源码对应不同的 AST,那么如果我们直接在 AST 层面修改,其对应的源码是否也会发生变化呢?答案是 yes,接下来我们就通过一个简单的例子学习一下如何修改 AST,以及如果通过 AST 生成 Python 源码。

sample.py

1
2
def foo():
print('func foo')

AST 表示

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
Module(
body=[
FunctionDef(
lineno=1,
col_offset=0,
name='foo',
args=arguments(args=[], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[]),
body=[
Expr(
lineno=2,
col_offset=4,
value=Call(
lineno=2,
col_offset=4,
func=Name(lineno=2, col_offset=4, id='print', ctx=Load()),
args=[Str(lineno=2, col_offset=10, s='func foo')],
keywords=[],
),
),
],
decorator_list=[],
returns=None,
),
],
)

假设我们想把上面 sample.py 中输出的内容修改成 hello world!ast 这个库也提供了相应的方法供我们参考。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import ast
import astunparse

# 读取源码并构建ast
content = open("sample.py").read()
script_ast = ast.parse(content)

# 定义一个访问类
class MyTransformer(ast.NodeTransformer):
def visit_Str(self, node):
node.s = 'hello world!'
return node

m = MyTransformer()
modified = m.visit(script_ast)
print(astunparse.Unparser(modified))

运行结果

1
2
def foo():
print('hello world!')

同理,我们看一下 ast 中关于 NodeTransformer 的定义

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
class NodeTransformer(NodeVisitor):
"""
A :class:`NodeVisitor` subclass that walks the abstract syntax tree and
allows modification of nodes. ...
"""

def generic_visit(self, node):
for field, old_value in iter_fields(node):
if isinstance(old_value, list):
new_values = []
for value in old_value:
if isinstance(value, AST):
value = self.visit(value)
if value is None:
continue
elif not isinstance(value, AST):
new_values.extend(value)
continue
new_values.append(value)
old_value[:] = new_values
elif isinstance(old_value, AST):
new_node = self.visit(old_value)
if new_node is None:
delattr(node, field)
else:
setattr(node, field, new_node)
return node

NodeTransformer 继承了 NodeVisitor ,所以定义遍历节点的方法是一样的。我们想修改 Str 节点,所以就重载了 visit_Str 这个函数,并对其进行修改。

  • generic_visit(self, node)

    我们直接观察第二个分支语句

    1
    2
    3
    4
    5
    6
    elif isinstance(old_value, AST):
    new_node = self.visit(old_value)
    if new_node is None:
    delattr(node, field)
    else:
    setattr(node, field, new_node)

    当遍历 old_value 时,会得到一个返回值 new_node ,然后如果返回值为 None ,则代表 old_value 被删除了,所以对当前遍历的节点 node 进行了删除属性操作 delattr(node, field) 。若不为 None ,则用新的值替换原有的 old_value ,即 setattr(node, field, new_node)

    所以当我们遍历修改完 node.s 之后,我们需要将修改后的节点进行返回。

    当这个节点 node 的处理都结束后, generic_visit 会返回修改后的节点。也就完成了 AST 的重构

通过第三方库函数 astunparse.Unparser ,我们可以将修改后的 AST 还原为源码。

更新于 阅读次数

请我喝[茶]~( ̄▽ ̄)~*

chaihj15 微信支付

微信支付

chaihj15 支付宝

支付宝