Extending Python with Pyrex
In a previous entry, I showed how to extend Python using SWIG, without writing directly against the Python C API. SWIG parses an interface file written by the programmer and outputs a wrapper module. A more elegant, more Pythonic solution might be a tool that allows a C module to be written directly in Python. Pyrex lets you write your extension module in a Python-like syntax that compiles to C. The C module can then be compiled using distutils (or gcc, et al.) to a binary Python module.
Pyrex has several advantages over brute-force solutions like SWIG. Pyrex makes the creation of new built-in data types a walk in the park. There are no cryptic functions for dealing with structs or unions. Perhaps most significant, though, is that modules written for Pyrex use a syntax very close to Python. In fact, according to the Pyrex website, “the fundamental nature of Pyrex can be summed up as follows: Pyrex is Python with C data types.”
Python is simple, elegant, and productive. That is why Pythonistas use Python instead of C. The one place where C is still required is when a piece of functionality would be too expensive if implemented in pure Python. Pyrex solves this by allowing the implementation to be written in a Python-like syntax and then compiled to C (which can in turn be compiled to machine code).
In the previous article on SWIG and Python, I wrote a simple C extension that counted the occurrences of each ASCII character in a string. The reason for this is that iteration over Python’s strings is slow. So let’s create a new data type using Pyrex, the CharIterator. Our type will be constructed from a Python string and be maintained internally as a character array. The type will implement the Python iterator protocol and return the ASCII code of the character at the current index. Here is our simple Pyrex implementation, which we will call mymod.pyx:
cdef class CharIterator:
    # Borrowed pointer: the caller must keep the original string alive.
    cdef char *value
    cdef int index
    cdef int last_char

    def __init__(self, char *value):
        self.value = value
        self.index = -1
        self.last_char = len(value)-1

    def __iter__(self):
        return self

    def __next__(self):
        # Stop once the last character has already been returned.
        if self.index >= self.last_char:
            raise StopIteration
        self.index = self.index + 1
        return self.value[self.index]
CharIterator maintains state in self.index, which is always one below the 0-based index of the next character to return. The reason for this is so that the character being returned from __next__ does not need to be stored in a temporary variable while index is incremented. The other possibility would be to increment first and return self.value[self.index-1], but that would be inelegant and inefficient (although inc/dec operations have minimal overhead). More information about the specifics of Pyrex syntax (such as cdef and when to use static type declarations) may be found in the Pyrex documentation.
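The index bookkeeping described above can be modelled in plain Python. This is only a sketch to illustrate the iterator protocol, not the Pyrex type itself: the class name PyCharIterator is made up for this example, and ord() stands in for the implicit char-to-int conversion that the C version gets for free.

```python
class PyCharIterator(object):
    """Pure-Python model of CharIterator (hypothetical name, illustration only)."""

    def __init__(self, value):
        self.value = value
        self.index = -1                   # one below the first valid index
        self.last_char = len(value) - 1   # index of the final character

    def __iter__(self):
        return self

    def __next__(self):
        # Stop once the previous call has already returned the last character.
        if self.index >= self.last_char:
            raise StopIteration
        self.index += 1                   # advance first ...
        return ord(self.value[self.index])  # ... then read, no temporary needed

    next = __next__  # Python 2 spells the iterator method "next"


codes = list(PyCharIterator("Hi!"))
print(codes)  # [72, 105, 33]
```

Because the index starts at -1, an empty string has last_char equal to -1 as well, so the very first call raises StopIteration and the loop body never runs.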
Now, compile the module. Write a simple setup.py:
from distutils.core import setup
from Pyrex.Distutils.extension import Extension
from Pyrex.Distutils import build_ext

setup(
    name = 'mymod',
    ext_modules=[Extension("mymod", ["mymod.pyx"])],
    cmdclass = {'build_ext': build_ext}
)
…and run “python setup.py build”, then “python setup.py install” to install to the Python site-packages directory. You can also skip the last command and manually copy the compiled module (mymod.so on Unix-like systems) from the build/lib.path_based_on_os_and_version/ directory. Now test the module:
from mymod import CharIterator

text = "Hello world."
chars = CharIterator(text)
for c in chars:
    print c, "=", chr(c)

# Output:
# 72 = H
# 101 = e
# 108 = l
# 108 = l
# 111 = o
# 32 =
# 119 = w
# 111 = o
# 114 = r
# 108 = l
# 100 = d
# 46 = .
A little testing shows that our CharIterator is a little under twice as fast as iterating over a plain Python string and counting the occurrences of each character (a difference of roughly 40% on my machine). The code used for the test was:
from time import clock
from mymod import CharIterator

def string_of_file(file_name):
    file_handle = open(file_name)
    file_text = file_handle.read()
    file_handle.close()
    return file_text

def string_of_int(i):
    # Labels start with a space so that they sort ahead of printable characters.
    if i > 32:
        return chr(i)
    elif i == 9:
        return " Tab"
    elif i == 10:
        return " Newline"
    elif i == 13:
        return " Carriage return"
    elif i == 32:
        return " Space"
    else:
        return " Not counted"

def count_chars_pyrex(text):
    # Use the argument, not the module-level file_text.
    chars = CharIterator(text)
    ascii = range(0, 128)
    counts = [0 for i in ascii]
    for c in chars:
        counts[c] += 1
    return convert_list(counts)

def count_chars_native(text):
    ascii = range(0, 128)
    counts = [0 for i in ascii]
    for c in text:
        i = ord(c)
        counts[i] += 1
    return convert_list(counts)

def convert_list(lst):
    ascii = range(0, 128)
    counts = {}
    for i in ascii:
        char = string_of_int(i)
        if char in counts:
            counts[char] += lst[i]
        else:
            counts[char] = lst[i]
    return counts

def print_counts(counts):
    chars = counts.keys()
    chars.sort()
    for k in chars:
        v = counts[k]
        if v > 0:
            print "%s: %d" % (k.lstrip(), v)

if __name__ == '__main__':
    file_name = "/Users/jober/Desktop/Projects/war_and_peace.txt"
    file_text = string_of_file(file_name)

    start = clock()
    native_counts = count_chars_native(file_text)
    end = clock()
    native_time = end-start
    print "Native implementation ran in %f seconds." % native_time

    start = clock()
    pyrex_counts = count_chars_pyrex(file_text)
    end = clock()
    pyrex_time = end-start
    print "Pyrex implementation ran in %f seconds." % pyrex_time

    print "----------"
    print "Subject text is %d characters long." % len(file_text)
    if pyrex_time < native_time:
        print "Pyrex implementation was faster by %f seconds." % (native_time - pyrex_time)
    else:
        print "Native implementation was faster by %f seconds." % (pyrex_time - native_time)
    print "----------"
    print_counts(pyrex_counts)
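The script above times each implementation exactly once, and a single measurement can be thrown off by whatever else the machine is doing. One possible refinement, sketched here under my own assumptions, is to run each function several times and keep the fastest run. The names best_time, count_with_ord and sample are hypothetical, and the sample text is a stand-in rather than the file used in the test.

```python
from timeit import default_timer


def count_with_ord(text):
    """Count ASCII characters by calling ord() on each one."""
    counts = [0] * 128
    for c in text:
        counts[ord(c)] += 1
    return counts


def best_time(func, arg, repeats=5):
    """Return the fastest wall-clock time of `repeats` calls to func(arg)."""
    best = None
    for _ in range(repeats):
        start = default_timer()
        func(arg)
        elapsed = default_timer() - start
        if best is None or elapsed < best:
            best = elapsed
    return best


sample = "Hello world. " * 10000  # stand-in text, not the file used above
seconds = best_time(count_with_ord, sample)
print("Best of 5 runs: %f seconds." % seconds)
```

With the compiled module installed, the same helper could be handed a function that loops over a CharIterator instead, so both implementations are measured in the same way.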
More information about Pyrex:
- Pyrex homepage
- Documentation
- Tutorial (http://ldots.org/pyrex-guide/ is apparently gone for good)