Learn Python ASTs by building your own linter
So what is an AST?
It's short for Abstract Syntax Tree. In programmer terms, "ASTs are a programmatic way to understand the structure of your source code". But to understand what that really means, we must first understand a few things about the structure of a computer program.
The programs that you and I write in our language of choice are usually called the "source code", and I'll be referring to them as such in this article.
On the other end, computer chips can only understand "machine code", which is a set of binary numbers that have special meanings for that model of the chip. Some of these numbers are instructions, which tell the CPU a simple task to perform, like "add the numbers stored in these two places", or "jump 10 numbers down and continue running code from there". The instructions run one by one, and they dictate the flow of the program.
Similarly, you define your programs as a set of "statements", with each statement being one thing that you want your code to do. They're sort of a more human-friendly version of the CPU instructions, that we can write and reason about more easily.
Now, I know that theory can get boring really quick, so I'm going to go through a bunch of examples. Let's write the same piece of code in many languages, and notice the similarities:
- Python
- Scheme Lisp
- Go
We're doing essentially the same thing in all of these, and I'll break it down piece by piece:
- We're defining our source code as a block of statements. In our case, there are two statements at the top-level of our source code: one statement that defines our area_of_circle function, and another statement that runs this function with the value "5".
- The definition of the area_of_circle function has two parts: the input parameters (the radius, in our case), and the body, which itself is a block of statements. There are two statements inside area_of_circle to be specific: the first one defines pi, and the second one uses it to calculate the area, and returns it.
- For the languages that have a main function, the definition of the main function itself is a statement. Inside that statement we are writing more statements, like one that prints out the value of area_of_circle called with the radius of 5.
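To make the breakdown above concrete, here is a sketch of what the Python version of this program could look like. The exact value of pi and the use of print are assumptions based on the description, not a copy of the original snippet.

```python
# Statement 1: the function definition (parameters + a body of two statements)
def area_of_circle(radius):
    pi = 3.14                    # first statement in the body: define pi
    return pi * radius * radius  # second statement: compute and return the area


# Statement 2: run the function with the value 5 and print the result
print(area_of_circle(5))
```

The file has exactly two top-level statements, and the function body is itself a block of two statements.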
You can start to see the somewhat repetitive nature of source code. There's blocks of statements, and sometimes within those statements there can be more statements, and so on. If you imagine each statement to be a "node", then you can think of each of these nodes being composed of one or more other "nodes". You can properly define this kind of structure as a "tree":
The nodes here can be anything, from statements, to expressions, to any other construct that the language defines. Once the code is in this tree structure, computers can start to make sense of it, such as by traversing its nodes one by one and generating the appropriate machine code.
Essentially, all your code represents a tree of data. And that tree is called the Abstract Syntax Tree. Each programming language has its own AST representation, but the idea is always the same.
To be able to create tools that do things like auto-format your code, or find subtle bugs automatically, you need ASTs to be able to meaningfully read through the code, find items or patterns inside the code, and act on them.
Python's ast module
Python has a builtin ast module, which has a rich set of features to create, modify and run ASTs from Python code. Not all languages provide easy access to their syntax trees, so Python is already pretty good in that regard. Let's take a look at what all the ast module gives us, and try to do something interesting with it:
All the Nodes
There are lots of kinds of "Nodes" in a Python AST, each with its own functionality, but you can broadly divide them into four categories: Literals, Variables, Statements and Expressions. We'll take a look at them one by one, but before we do that we need to understand how a "Node" is represented.
The role of a node is to concretely represent the features of a language.
It does so by:
- Storing the attributes specific to itself. (For example, an If node that represents an if-statement might need a condition attribute, which is an expression that evaluates to true or false. The if statement's body will only run when condition ends up being true.)
- Defining what children the node can have. (In our If node's case, it should have a body, that is a list of statements.)
In Python's case, the AST nodes also hold their exact location in the source code. You can find out where in the Python file a node came from by checking its lineno and col_offset attributes.
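Here's a small sketch of reading those location attributes. The three-line program is made up for illustration; lineno is 1-based and col_offset is 0-based.

```python
import ast

# A tiny made-up program: the `if` is on line 2, the print call on line 3
source = "x = 1\nif x:\n    print(x)\n"
tree = ast.parse(source)

if_node = tree.body[1]               # second top-level statement
call_node = if_node.body[0].value    # the print(x) call inside the if body

if_location = (if_node.lineno, if_node.col_offset)
call_location = (call_node.lineno, call_node.col_offset)

print("if statement at", if_location)
print("print call at", call_location)
```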
Let's see the concrete example of this if statement, in Python's AST representation.
For this source code:
The AST looks like this:
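If you want to reproduce this example yourself, a sketch like the following should do. The condition answer == 42 and the print call come from the discussion below, but the string being printed is an assumption, and the indent argument to ast.dump needs Python 3.9 or newer.

```python
import ast

# The message passed to print() is a placeholder, not the original text
source = """\
if answer == 42:
    print('Correct answer!')
"""

tree = ast.parse(source)
if_node = tree.body[0]

# Pretty-print the whole tree: Module -> If -> (test, body, orelse)
print(ast.dump(tree, indent=4))
```

The exact text that ast.dump prints varies a little between Python versions, so it's worth running this on your own interpreter rather than relying on a pasted dump.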
Let's break this down:
Ignoring the details for now, the overall structure of the AST looks like this:
At the top level, is a Module. All Python files are compiled as "modules" when making the AST. Modules have a very specific meaning: anything that can be run by Python classifies as a module. So by definition, our Python file is a module.
It has a body, which is a list. Specifically, a list of statements. All Python files are just that: a list of statements. Every Python program that you've ever written, read or run -- just that.
In our case, we have just one statement in the module's body: an If-statement. The if-statement has two components: a test, and a body. The test part holds the condition expression, and the body holds the block of statements that's inside the if.
Let's look at the test "expression" first:
In our case, we have a Compare expression -- which makes sense. Python defines comparisons quite thoroughly in its reference, and if you read it, you'll find that Python supports comparison chaining.
From the docs:
Python's comparison expressions support this syntax:
a op1 b op2 c ...
Which is equivalent to:
a op1 b and b op2 c and ...
In human terms, this means that Python can support stuff like this:
And 0 < x < 10 is the same as asking 0 < x and x < 10.
Here's the important part: for Python to support this, the AST needs to support this. And Python's AST supports comparison chaining by storing the operators and the comparators (variables) inside lists. You can look at it in the REPL itself:
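As a sketch, here's how you could inspect a chained comparison yourself. The expression a < b > c > d is chosen to match the operators discussed next; mode="eval" just tells ast.parse to expect a single expression.

```python
import ast

# Parse a single expression with three chained comparison operators
compare = ast.parse("a < b > c > d", mode="eval").body

print(ast.dump(compare))

# Pull the interesting parts out into plain Python values
left = compare.left.id
ops = [type(op).__name__ for op in compare.ops]
comparators = [name.id for name in compare.comparators]

print(left, ops, comparators)
```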
You can see that the operators <, > and > are stored as ops=[Lt(), Gt(), Gt()] inside the Compare object. The four values are stored a bit more peculiarly: The variable a is stored in a separate field called left, and then every other variable is stored in a list called comparators:
In other words: the leftmost variable is stored in left, and every variable on the right of each operator is stored in the respective index of comparators.
Hopefully that clarifies what the test expression means in our example code:
left is the Name node 'answer' (basically, a variable), and we have just one comparison going on: Eq being applied on the constant value 42. Essentially it is the answer == 42 part of the code.
Now let's look at the body:
The body in our case is a single Expression. Note that, when I said that a block or module always contains a list of statements, I wasn't lying. This Expr right here is actually an expression-statement. Yeah, I'm not making this up, it will make sense in a bit.
Expressions vs. Statements
Statements are pretty easy to define. They're kind of like the building blocks of your code. Each statement does something that you can properly define. Such as:
- Creating a variable
- This one becomes an Assign statement:
- Pretty straightforward, the node stores a list of targets and a value. targets is a list because you can also do multiple assignments: a = b = 5. There will only be one value, though.
- Importing a module
- This one becomes an Import statement:
- Asserting some property
- Becomes:
- Doing absolutely nothing
- Becomes:
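The four kinds of statements listed above can be inspected with a short loop. The one-line snippets here are illustrative stand-ins, not necessarily the ones originally shown.

```python
import ast

# One made-up example per statement kind from the list above
snippets = ["a = b = 5", "import os", "assert x > 0", "pass"]

statements = [ast.parse(source).body[0] for source in snippets]
statement_types = [type(stmt).__name__ for stmt in statements]

for source, stmt in zip(snippets, statements):
    print(f"{source!r:16} -> {ast.dump(stmt)}")

# `a = b = 5` has two targets but a single value
assign = statements[0]
print(len(assign.targets), "targets, value =", assign.value.value)
```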
On the other hand, an expression is basically anything that evaluates to a value. Any piece of syntax that ends up turning into a "value", such as a number, a string, an object, even a class or function. As long as it returns a value to us, it is an expression.
This includes:
- Identity checks
- This refers to the is expression:
- Clearly, a is b returns either True or False, just like any other conditional check. And since it returns a value, it is an expression.
- Here's its AST:
- And it really is just like conditionals. Turns out is is treated just as a special operator (like <, == and so on) inside a Compare object when talking about ASTs.
- Function calls
- Function calls return a value. That makes them the most obvious example of an expression:
- Here's what the AST for getpid() looks like, it's essentially just a Call:
- print('Hello') would look like this; it has one argument:
- Lambdas
- Lambdas themselves are expressions. When you create a lambda function, you usually pass it directly as an argument to another function, or assign it to a variable. Here's some examples:
- And there's the Lambda expression for lambda: 5 in the AST:
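Here's a sketch that parses each of the expressions mentioned above, so you can poke at their nodes. parse_expr is a small helper defined here for convenience, it isn't part of the ast module.

```python
import ast


def parse_expr(source):
    """Hypothetical helper: parse one expression and return its node."""
    return ast.parse(source, mode="eval").body


is_expr = parse_expr("a is b")             # identity check
call_expr = parse_expr("getpid()")         # call with no arguments
print_expr = parse_expr("print('Hello')")  # call with one argument
lambda_expr = parse_expr("lambda: 5")      # lambda returning a constant

for node in (is_expr, call_expr, print_expr, lambda_expr):
    print(type(node).__name__, "->", ast.dump(node))
```

Every one of these is an instance of ast.expr, which is the base class for expression nodes.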
Now, if you think about it, a call to print() in regular code is technically a statement, right?
As I've said before, blocks of code are essentially just a list of statements. And we also know, that calling print is technically an expression (it even returns None!). So what's going on here?
The answer is simple: Python lets you treat any expression as a standalone statement. The expression is going to return some value, but that value just gets discarded.
Getting back to our original AST:
We have an Expr in our body, which is Python's way of saying "This is an expression that's being used as a statement". The actual expression is inside it, a Call to print.
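You can check this wrapping for yourself. In the sketch below, three bare expressions (made up for illustration) are written on their own lines, and each one shows up in the module body as an Expr statement holding the real expression in its value field.

```python
import ast

# Three expressions used as standalone statements; their values are discarded
tree = ast.parse("1 + 2\n'just a string'\nprint('hi')\n")

wrapper_types = [type(stmt).__name__ for stmt in tree.body]
inner_types = [type(stmt.value).__name__ for stmt in tree.body]

print(wrapper_types)  # the statement-level wrappers
print(inner_types)    # the actual expressions inside them
```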
The last thing left in this example AST is the last line: orelse=[]. orelse refers to else: blocks anywhere in the AST. The name orelse was chosen because else itself is a keyword and can't be used as an attribute name.
Oh, did you know that for loops in Python can have an else clause?
Extras: The for-else clause
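As a quick sketch of how that looks in the AST: the else block of a for loop runs only when the loop finishes without hitting break, and it lands in the same orelse field that If nodes use. The loop below is a made-up example.

```python
import ast

source = """\
for item in items:
    if item == target:
        break
else:
    print('not found')
"""

for_node = ast.parse(source).body[0]

# The loop body and the else block are stored in separate lists
body_types = [type(stmt).__name__ for stmt in for_node.body]
orelse_types = [type(stmt).__name__ for stmt in for_node.orelse]

print("body:", body_types)
print("orelse:", orelse_types)
```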
If you want a detailed reference of all the Nodes that we have in a Python AST, and the corresponding syntax it belongs to, you can either head on to the docs here, or just use ast.dump on a line of code to try and find out for yourself.
What's a ctx?
Ideally, I want you to leave this article understanding every single aspect of Python's ASTs. And if you're one of the few super observant readers, you might have noticed that we glossed over a very small thing in the AST examples shown. You can see it in this code snippet:
We've talked about Compare, we've talked about what the left, ops and comparators fields represent, we've also talked about Name nodes. The only thing left is ctx=Load(). What exactly does that mean?
If you check all the code snippets we've seen so far, we've actually seen 27 instances of Name nodes in the examples. Out of the 27, 25 have had the property ctx=Load(), but two of them have a different value: ctx=Store(). Like this one:
ctx (short for "context") is an essential concept of Python (and many other programming languages), and it is related to the whole concept of "variables".
If I were to ask you "what's a variable?" You might say something like "It can store values which you can use later.", and give some example like:
And that's exactly what it is. If you look at the AST for this code:
So the first statement is an Assign, and the variable age is in the "Store" context (because a new value is being stored into it), and in the second statement it is in "Load" context. Interestingly, print itself is a variable that's being loaded in this statement. Which makes sense, print is essentially a function somewhere in memory, which is accessible by us using the name print.
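Here's a sketch that checks those contexts directly. The value 21 is an arbitrary choice; any assignment followed by a print of the same variable behaves the same way.

```python
import ast

tree = ast.parse("age = 21\nprint(age)")

assign = tree.body[0]      # age = 21
call = tree.body[1].value  # print(age)

contexts = {
    "age (assigned)": type(assign.targets[0].ctx).__name__,
    "print": type(call.func.ctx).__name__,
    "age (printed)": type(call.args[0].ctx).__name__,
}

for name, ctx in contexts.items():
    print(f"{name}: {ctx}")
```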
Let's look at a couple more. What about this?
age = age + 1
The AST looks like this:
The age on the right is in "Load" mode and the one on the left is in "Store" mode. That's why this line of code makes sense: we are Loading the old value, adding 1 to it, and then Storing it.
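The same check works for the self-assignment; this sketch just picks the two age nodes out of the Assign.

```python
import ast

assign = ast.parse("age = age + 1").body[0]

target = assign.targets[0]  # the `age` on the left of the =
binop = assign.value        # the `age + 1` on the right

print("left side: ", target.id, type(target.ctx).__name__)
print("right side:", binop.left.id, type(binop.left.ctx).__name__)
```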
This should probably help in explaining how this kind of self-assigning code really works to a newbie programmer in the future.
One more interesting example is this:
The AST looks like this:
Here, x is actually in "Load" mode even though it's on the left side of the assignment. And if you think about it, it makes sense. We need to load x, and then modify one of its indices. It's not x which is being assigned to, only one index of it is being assigned. So the part of the AST that is in Store context is the Subscript, i.e. it is x[5] that's being assigned a new value.
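A sketch of that, using x[5] = y as a stand-in for the assignment being discussed (the right-hand side is an assumption):

```python
import ast

assign = ast.parse("x[5] = y").body[0]
target = assign.targets[0]

# The Subscript as a whole is what gets stored to...
subscript_ctx = type(target.ctx).__name__
# ...while the variable x inside it is only loaded
inner_name = target.value.id
inner_ctx = type(target.value.ctx).__name__

print(type(target).__name__, subscript_ctx)
print(inner_name, inner_ctx)
```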
Hopefully this explains why we explicitly need to tell each variable whether it is in a load or store context in the AST.
Now unless you're super familiar with Python, you'd think that Load and Store cover everything that the language needs, but weirdly enough there's a third possible ctx value, Del:
Extras: why `del` exists
And just like Store, you can also Del an attribute or an index, and it behaves similarly:
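Here's a sketch covering all three cases: deleting a plain variable, an attribute, and an index. The names are made up for illustration.

```python
import ast

tree = ast.parse("del x\ndel obj.attr\ndel items[0]")

# Each Delete statement has a list of targets; we only have one per line
targets = [stmt.targets[0] for stmt in tree.body]
summary = [(type(t).__name__, type(t.ctx).__name__) for t in targets]

for kind, ctx in summary:
    print(kind, ctx)

# As with Store, the object being indexed or accessed is itself only loaded
attribute_owner_ctx = type(targets[1].value.ctx).__name__
subscript_owner_ctx = type(targets[2].value.ctx).__name__
```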
Walking the Syntax Trees with Visitors
So now we know that our AST represents code using nested Nodes, a structure that is called a "tree". We also know that in a tree structure, a node can have as many children nodes inside it as needed. With all that, comes the question of how does one "read" the tree.
The most obvious way would be to read it top to bottom, the way it appears in the AST dump that we've seen so many times:
We have a Module, which has two nodes in its body: an Assign which has a Name and a Constant, and an Expr which has a Call with a couple Name nodes.
This way of reading from parent node to child node, in the sequence they appear, is called a pre-order traversal of the tree. And for most intents and purposes, it is what you need.
To implement this sort of traversal of an AST, Python provides you the NodeVisitor class, which you can use like this:
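A minimal sketch of such a visitor is below. It overrides generic_visit, prints the class name of each node, and then lets the parent class continue the traversal. The visited list is an addition of mine so that the order can be checked afterwards.

```python
import ast


class MyVisitor(ast.NodeVisitor):
    def __init__(self):
        self.visited = []  # extra bookkeeping, only here to record the order

    def generic_visit(self, node):
        name = node.__class__.__name__
        print(f"entering {name}")
        self.visited.append(name)
        # Continue the traversal into this node's children
        super().generic_visit(node)


tree = ast.parse("x = 5\nprint(x)")
visitor = MyVisitor()
visitor.visit(tree)
```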
We'll look closer at the code inside MyVisitor very soon, but let's try and examine its output first.
This outputs the following:
To make this output slightly more detailed, let's not just print the class name, but the entire node:
There's a lot more output, but it might help to look at it:
entering Module(body=[Assign(targets=[Name(id='x', ctx=Store())], value=Constant(value=5)), Expr(value=Call(func=Name(id='print', ctx=Load()), args=[Name(id='x', ctx=Load())], keywords=[]))], type_ignores=[])
entering Assign(targets=[Name(id='x', ctx=Store())], value=Constant(value=5))
entering Name(id='x', ctx=Store())
entering Store()
entering Constant(value=5)
entering Expr(value=Call(func=Name(id='print', ctx=Load()), args=[Name(id='x', ctx=Load())], keywords=[]))
entering Call(func=Name(id='print', ctx=Load()), args=[Name(id='x', ctx=Load())], keywords=[])
entering Name(id='print', ctx=Load())
entering Load()
entering Name(id='x', ctx=Load())
entering Load()
You might notice that the first line output is the entire Module. The second line is the Assign node that is inside the Module, and the third line is a Name node which is inside that. Interesting.
With that, let's understand what's going on. You can imagine this "visitor" moving from top to bottom, left to right in this tree structure:
It starts from the Module, then for each child it visits, it visits the entirety of one of its children before going to the next child. As in, it visits the entirety of the Assign sub-tree before moving on to the Expr part of the tree, and so on.
How it does this, is all hidden in our generic_visit() implementation. Let's start tweaking it to see what results we get. Here's a simpler example:
Now let's move the print statement to below the super call, see what happens:
Interesting. So now the prints suddenly happen in sort-of "reverse" order. It's not actually reversed though, but now every child appears before the parent. This bottom-to-top, left-to-right traversal is called post-order traversal.
So how about if we do both prints together?
If you follow the enter and leave commands one by one, you'll see how this traversal is happening. I've added the corresponding line numbers for each node in an [enter, leave] pair in this graph, and you can follow the traversal from 1 through 10:
You can keep this in mind, that anything that comes before the super() call is being done in pre-order, and anything that comes after the super() call is being done in post-order.
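To see both orders side by side, here's a sketch that records an event before and after the super() call. The one-line program x = 5 is an assumption; it has five nodes, so you get ten events in total.

```python
import ast


class EnterLeaveVisitor(ast.NodeVisitor):
    def __init__(self):
        self.events = []

    def generic_visit(self, node):
        name = type(node).__name__
        self.events.append(f"enter {name}")  # before super(): pre-order
        super().generic_visit(node)
        self.events.append(f"leave {name}")  # after super(): post-order


visitor = EnterLeaveVisitor()
visitor.visit(ast.parse("x = 5"))

for number, event in enumerate(visitor.events, start=1):
    print(number, event)
```

Reading only the "enter" lines gives you the pre-order sequence, and reading only the "leave" lines gives you the post-order one.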
So let's say that for some reason I wanted to find how many statements exist inside all the for loops in my code. To do that, I'd need to do the following:
- Traverse through the code to find all For nodes. We're already sorted with that.
- Each time we see a For node, we need to start our count from zero.
- We must keep counting until we see this same For node again during post-order traversal.
- For every node we find below the For node, check if it's a statement, and increment count.
The code for that would look like this:
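Here's one way to sketch it. This version keeps a stack of running counts so that nested for loops are handled too, which is a design choice of mine; the sample program with its two loops is also made up.

```python
import ast


class ForStmtCounter(ast.NodeVisitor):
    def __init__(self):
        self.open_counts = []  # one running count per for-loop we are inside
        self.results = []      # final count for each loop, in post-order

    def generic_visit(self, node):
        if isinstance(node, ast.stmt):
            # This statement sits below every for-loop currently open
            self.open_counts = [count + 1 for count in self.open_counts]
        if isinstance(node, ast.For):
            self.open_counts.append(0)  # pre-order: start counting from zero

        super().generic_visit(node)

        if isinstance(node, ast.For):
            count = self.open_counts.pop()  # post-order: the loop is done
            self.results.append(count)
            print(f"For node contains {count} statements")


source = """\
for i in range(10):
    print(i)

for j in range(5):
    x = j * 2
    y = x + 1
    print(y)
"""

counter = ForStmtCounter()
counter.visit(ast.parse(source))
```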
And this is the output:
For node contains 1 statements
For node contains 3 statements
The power of AST manipulation
The real power of ASTs comes from the fact that you can edit an AST, and then compile and run it, to modify your source code's behaviour at runtime.
To get into all of that, let me explain a little bit about how Python runs your source code:
- The first step is to parse the source code. This actually involves two steps, converting the source code into tokens, and then converting the tokens into a valid AST. Python neatly wraps both of these steps in the ast.parse function. You can examine the AST produced in the output above.
- The next step is to "compile" the given AST into a code object. Code objects are objects that contain compiled pieces of Python "bytecode", the variable names present in the bytecode, and the locations of each part in the actual source. The compile function takes in the AST, a file name (which we set to '<my ast>' in our case), and a mode, which we set to 'exec', which tells compile that we want an executable code object to come out.
- The third step is to run exec() on this code object, which runs this bytecode in the interpreter. It provides the bytecode with all the values present in the local and global scopes, and lets the code run. In our case, this makes the code print out hello.
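The three steps above can be sketched in a few lines. The source string here is an assumption based on the "prints out hello" description.

```python
import ast

source = "print('hello')"

# Step 1: source code -> AST
tree = ast.parse(source)

# Step 2: AST -> code object (the file name is only used in tracebacks)
code_object = compile(tree, filename="<my ast>", mode="exec")

# Step 3: run the bytecode held by the code object
exec(code_object)
```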
If you want a more in-depth explanation of this, I have an entire section on it in my builtins blog.
Now since we have the AST in step 1 of this part, we can simply modify the AST, before we run the compile and execute steps, and the output of the program will be different. How cool is that!
Let's just jump into it with a few simple examples, before we do something really awesome with this.
Let's write one that changes all numbers in our code to become 42, because why not. Here's our test code:
Running this, we get:
Now, if all numbers in this code were 42, our output would be:
The last 84 is from 42 + 42 (instead of 28 + 100).
So, how would you do this? It's quite straightforward actually. What we do is define a NodeTransformer class. The difference between this and a NodeVisitor is that a Transformer actually returns a node on every visit, which replaces the old node:
Let's run it and see our output:
And, as expected, the output is:
... wait. That's not right. Something's definitely not right.
And with that, we are going to talk about something that's really important when it comes to playing around with ASTs: Sometimes, it can be quite hard to get right.
A tutorial can make any topic seem easy and obvious, but oftentimes it misses out on the whole learning process of making mistakes, discovering new edge cases, and actually understanding how to debug some of these things.
Okay, mini-rant over. Let's try and debug this thing. To do that, the first thing that we should do is head to the docs:
The first line below a code example says:
"Keep in mind that if the node you’re operating on has child nodes you must either transform the child nodes yourself or call the generic_visit() method for the node first."
Ah, yes, we forgot the super() call. But why does that matter?
The super() call is what propagates the tree traversal down the tree. If you don't call that, the visit method will stop on that specific node, and never visit the node's children. So, let's fix this:
Let's run it again:
Uh oh. Another error. Welp, back to the docs.
This is what the rest of the section on NodeTransformer says:
That's exactly what we need to do, run fix_missing_locations. Always read the entirety of the docs, folks. Let's run it again:
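Putting both fixes together, a working version might look like the sketch below. The test program is a stand-in that reuses the 28 + 100 sum mentioned earlier, and the code runs in its own namespace dictionary so the result can be inspected afterwards.

```python
import ast


class NumberChanger(ast.NodeTransformer):
    """Replace every integer literal with 42."""

    def generic_visit(self, node):
        # Fix 1: keep the traversal going into the children first
        super().generic_visit(node)
        # type(...) is int deliberately skips True/False, which are also ints
        if isinstance(node, ast.Constant) and type(node.value) is int:
            return ast.Constant(value=42)
        return node


source = """\
print(13)
total = 28 + 100
print(total)
"""

tree = ast.parse(source)
modified = NumberChanger().visit(tree)

# Fix 2: the freshly created Constant nodes have no lineno/col_offset yet
ast.fix_missing_locations(modified)

namespace = {}
exec(compile(modified, "<my ast>", "exec"), namespace)
```

With every number turned into 42, total ends up as 42 + 42.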
Finally! We were able to modify and run our AST 🎉
Let's go a little more in depth. This article is super long already, might as well add some more interesting stuff.
Since it's very common for AST modifications to deal with a specific kind of node, and nothing else (we've already seen a few examples where that would've been useful, such as turning every number into 42), the NodeVisitor and NodeTransformer classes both let you define node-specific visitor methods.
You define a node-specific visitor method by defining a visit_<NodeName> method, such as visit_For to visit only for-loops. If the visitor encounters a for loop, it will first see if a visit_For is defined in the class, and run that. If there isn't, it runs generic_visit as the fallback.
Here's a somewhat wacky example, which lets you run a program that outputs to the terminal, and make it so that it outputs to a file instead. To do that, we're going to rewrite every print call, and add a keyword argument, file=..., which will make it print to that file instead.
Let's see what the AST looks like, for a print() call with and without a file=... argument.
So we need to find every Call with the func attribute being Name(id='print'), and add a file keyword to the Call's keywords list.
Time to test this:
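Here's a minimal sketch of such a transformer together with a small test. The variable name output_file is an assumption, and an in-memory io.StringIO stands in for a real file so the result is easy to check.

```python
import ast
import io


class PrintToFile(ast.NodeTransformer):
    def visit_Call(self, node):
        # Visit children first, so nested calls like print(len(x)) still work
        self.generic_visit(node)
        if isinstance(node.func, ast.Name) and node.func.id == "print":
            # Append file=output_file to the call's keyword arguments
            file_keyword = ast.keyword(
                arg="file",
                value=ast.Name(id="output_file", ctx=ast.Load()),
            )
            node.keywords.append(file_keyword)
        return node


tree = ast.parse("print('hello')\nprint(1 + 2)")
tree = PrintToFile().visit(tree)
ast.fix_missing_locations(tree)  # the new keyword nodes need locations

# The rewritten code expects a global called output_file to exist
buffer = io.StringIO()
exec(compile(tree, "<my ast>", "exec"), {"output_file": buffer})

captured = buffer.getvalue()
print(repr(captured))
```

To write to an actual file, you'd pass an open file object as output_file instead of the StringIO buffer.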
If you want to, a nice exercise for checking your understanding would be to re-write every generic_visit based code we have written so far, and simplify it using visit_X methods instead.
Let's build: A simple linter
We've learned all the key components for this, all that's left to do is to put everything together. Let's write our own linter from scratch.
Here's the idea:
- We're gonna make a Linter class, which holds our "lint rules".
- Lint rules are the actual checks that run on the code. They consist of three things:
- A rule "code" that uniquely identifies it,
- A message that explains the rule violation to the user,
- And a Checker class, which is an AST visitor that checks which nodes violate this rule in the source code.
- Our linter class will register these rules, run them on a file, and print out all violations.
So let's write down our linter framework, mylint.py:
Sweet. Now that we have a framework, we can start writing our own checkers. Let's start with a simple one, one that checks if a set has duplicate items:
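Here's a compact sketch of both pieces in one place: a stripped-down framework and the duplicate-set-item checker on top of it. The names Violation, Checker, Linter, the W001 code and the report format are all my own assumptions, not necessarily what the real mylint.py uses.

```python
import ast
from typing import NamedTuple


class Violation(NamedTuple):
    node: ast.AST
    message: str


class Checker(ast.NodeVisitor):
    """Base class for lint rules: an AST visitor that records violations."""

    def __init__(self, issue_code):
        self.issue_code = issue_code
        self.violations = []


class Linter:
    def __init__(self):
        self.checkers = []

    def run(self, source, filename="<string>"):
        tree = ast.parse(source, filename=filename)
        reports = []
        for checker in self.checkers:
            checker.violations.clear()  # don't carry results between runs
            checker.visit(tree)
            for node, message in checker.violations:
                reports.append(
                    f"{filename}:{node.lineno}:{node.col_offset}: "
                    f"{checker.issue_code}: {message}"
                )
        return reports


class SetDuplicateItemChecker(Checker):
    def visit_Set(self, node):
        seen = set()
        for element in node.elts:
            if not isinstance(element, ast.Constant):
                continue  # only literal values can be compared statically
            key = (type(element.value), element.value)
            if key in seen:
                message = f"Set contains duplicate item: {element.value!r}"
                self.violations.append(Violation(node, message))
            else:
                seen.add(key)
        self.generic_visit(node)


linter = Linter()
linter.checkers.append(SetDuplicateItemChecker(issue_code="W001"))
reports = linter.run("s = {1, 2}\nl = {1, 2, 3, 1, 2, 3}\n")
for line in reports:
    print(line)
```

The sketch lints a source string directly; reading file names from the command line is a thin layer on top of Linter.run.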
The only thing left, is to write a main function, that takes the filenames from the command line, and runs the linter:
That was a lot of code, but hopefully you were able to make sense of all of it. Alright, time to lint some files. I wrote a test.py file to test our linter:
Let's run:
We've successfully written a linter!
The real fun starts though, with the really intricate lint rules that you can write. So let's write one of those. Let's try to write a checker that checks for unused variables.
Here's the idea: Every function has a local scope, where you can define new variables. These variables only exist within that function, they have to be created and used all within that function. But, if a variable is defined but isn't used in a function, that's an unused variable.
To detect unused variables, we can visit all Name nodes inside a function or class, and see if there's any in Load context. If a variable is only ever present in Store context, that means it's defined but never used. So let's do that:
There's just one caveat: We can't just use this visitor on our entire file, as it won't find all unused variables. Here's a quick code example to demonstrate:
Here, we would see var being used in the global scope. Because of that, the checker won't catch the unused var inside func(), if we only run it on the entire file. We actually want to run this checker on every single function and class in the file.
So that's exactly what we are going to do. We will write a checker that runs this checker inside it. Yeah, my mind was also blown when I realised I can run visitors inside visitors. But here's the plan: For every ClassDef and FunctionDef, we will run the above checker on them to find unused local variables, and we will also run it on the Module to find globally unused variables. Sounds good?
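A standalone sketch of this visitor-inside-a-visitor idea is below. It isn't wired into the linter framework, it records plain (name, line) pairs, and the class names are made up. One simplification of mine: the inner visitor skips nested functions and classes so each scope is only reported once, which also means a variable used only by a nested function will be wrongly flagged.

```python
import ast


class ScopeNames(ast.NodeVisitor):
    """Collect the names stored and loaded directly inside one scope."""

    def __init__(self):
        self.stored = {}     # name -> line of its first assignment
        self.loaded = set()  # names that are read at least once

    def visit_Name(self, node):
        if isinstance(node.ctx, ast.Store):
            self.stored.setdefault(node.id, node.lineno)
        elif isinstance(node.ctx, ast.Load):
            self.loaded.add(node.id)

    def skip_nested_scope(self, node):
        pass  # nested scopes get their own ScopeNames run

    visit_FunctionDef = visit_AsyncFunctionDef = visit_ClassDef = skip_nested_scope


class UnusedVariableChecker(ast.NodeVisitor):
    def __init__(self):
        self.violations = []

    def check_scope(self, node):
        names = ScopeNames()
        for statement in node.body:
            names.visit(statement)  # run the inner visitor on this scope
        for name, lineno in names.stored.items():
            if name not in names.loaded:
                self.violations.append((name, lineno))
        self.generic_visit(node)    # keep looking for more scopes

    visit_Module = visit_FunctionDef = visit_ClassDef = check_scope


source = """\
var = 5
print(var)

def func():
    var = 10
    unused = 3
"""

checker = UnusedVariableChecker()
checker.visit(ast.parse(source))
for name, lineno in checker.violations:
    print(f"line {lineno}: variable {name!r} is assigned but never used")
```

The global var is used, so it isn't reported, but the var inside func() is caught because that scope is checked on its own.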
All that's left now, is to enable this checker in the main function:
I've also prepared a new test file, with some more variables:
Alright, let's run it!
It works! We did it, we've built our own linter 🥳
If you've made it all the way to this part of the post, congratulations. Even though I've tried my best, I wouldn't be surprised if people still find this article too hard to follow. But, that's just the nature of code analysis -- it's hard. It took me 6 months of working with Python's AST every single day to be able to write this blog, so if you've appreciated the work I'd love to hear about it 🙌
AST utilities
The ast module gives you a few other useful utility functions:
- ast.literal_eval can be used as a safer alternative to eval, as literal_eval doesn't access the Python environment, and can only parse literal values, like 'abc' or "{'name': 'Mike', 'age': 25}". But sometimes, that's all you need to evaluate, and it's far safer than using eval for that, since it never executes arbitrary code.
- ast.unparse was added in Python 3.9, and can take an AST node and convert it back to source code.
- Note that ASTs don't store all the information from a source file, so things like exact whitespace information, comments, use of single/double quotes etc. will be lost when you do this. But it still works:
- Re-parsing an un-parsed AST again should generate the same AST.
- ast.walk takes in an AST node and returns a generator that yields that node and all of its descendants one by one, in no specific order.
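Here's a short sketch that exercises all three utilities. The inputs are made up, and ast.unparse needs Python 3.9 or newer.

```python
import ast

# literal_eval: evaluates literals only, anything else raises ValueError
data = ast.literal_eval("{'name': 'Mike', 'age': 25}")
try:
    ast.literal_eval("__import__('os').getcwd()")
    rejected = False
except ValueError:
    rejected = True

# unparse: the comment and the extra spaces are not part of the AST
tree = ast.parse("x   =   'a'   # a comment")
regenerated = ast.unparse(tree)

# Re-parsing the regenerated code gives back an equivalent tree
round_trip_ok = ast.dump(ast.parse(regenerated)) == ast.dump(tree)

# walk: yields the node itself plus every descendant, in no specified order
node_types = sorted({type(node).__name__ for node in ast.walk(tree)})

print(data, rejected)
print(regenerated, round_trip_ok)
print(node_types)
```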
What about code formatters?
A little while ago I mentioned that ASTs don't contain certain information like whitespaces, comments, and quote styles. This is because ASTs are meant to be a representation of what the code means, not how the code looks. Anything that Python won't need to store to run the code, is stripped out.
Turns out, this is a huge problem if you want to create things like code formatters, which need to be able to produce the exact source code that was used to build the syntax trees. If you don't, every time the code formatter runs on your file, the whole file will look completely different from how you wrote the code.
For such use cases, we need what is called a Concrete Syntax Tree, or CST. A CST is essentially just an AST which contains style information as well, such as where the newlines are, how many spaces are used for indentation, and so on.
If you need to build a code editor, formatter, or something of that sort for Python, I'll highly recommend libcst. It's the best CST library for Python that I know of.
Where can I learn more?
If you somehow still want to learn more about ASTs, then my first recommendation would be to read all of the documentation of the ast module. Furthermore, you should also check out greentreesnakes, which was the original source of most of the code examples in the official tutorial today. There's a lot of material to read there.
One of the craziest parts about the ast module is that, even with all the amazing things it lets us do, the entire source code ast.py is just 1700 lines of Python code. It's a definite must-read if you want to dive deeper into ASTs.
The linter that I wrote can be found on my github. It can be thought of as an extremely simplified version of pylint, one of the most popular linters in Python, which also has its own AST wrapper called astroid.
And with that, you've reached the end of the article. I hope you find good use of this.