Build your own iterator on top of enumerate

Posted on Mon, 27 Jun 2016 in Python

A couple of days ago one of my colleagues asked me to implement a simple self-logging iterator inherited from enumerate. I tried to inherit from it directly and failed: I had completely forgotten how to use the __new__ magic method. It was so embarrassing. I was in a hurry, so I promised myself to solve this puzzle later. And I can tell you it is a simple task. Actually, it takes only 18 lines of code.

Initial task

First of all, I have to explain why we need such an iterator and why the standard enumerate isn't enough.

In our project we had a lot of console tasks with a similar structure:

  1. Get a bunch of objects from the DB
  2. Log how many objects you have
  3. Do something with each object from the bunch, logging progress every X objects
  4. Log when finished

In Python it looked like this:

iterable = get_bunch()
total = len(iterable)

print("total: {}".format(total))

for i, item in enumerate(iterable, start=1):
    try:
        func(item)
    except Exception as e:
        print("catch exception: {}".format(e))

    if not i % 100:
        print("done {} of {}".format(i, total))

print("Done!")

The tasks differed only in their log messages and in the func function. It was essentially copy-paste code, so we decided to refactor it.

Implementation

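Let's try to make a class inherited from enumerate. As I mentioned above, we have to override the __new__ method, because enumerate does all of its argument handling in __new__ rather than in __init__. If we only override __init__ and pass extra arguments such as step, enumerate.__new__ receives them too and raises TypeError. Remember also that, according to the documentation, if __new__() returns an instance of cls, then the new instance's __init__() method will be invoked with the same arguments that were passed to the constructor.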
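A small sketch of both points follows. NaiveLogEnumerate and FixedLogEnumerate are hypothetical classes made up for this illustration: the first overrides only __init__, so the extra step keyword reaches enumerate.__new__ and fails; the second filters the arguments in __new__, after which __init__ still receives the full original argument list.

```python
class NaiveLogEnumerate(enumerate):
    # Only __init__ is overridden, so enumerate.__new__ still receives every argument
    def __init__(self, iterable, start=1, step=10):
        self.step = step


class FixedLogEnumerate(enumerate):
    # __new__ forwards only what enumerate understands and swallows the extras
    def __new__(cls, iterable, start=1, *args, **kwargs):
        return super().__new__(cls, iterable, start)

    # __init__ is then called with the same arguments that were passed to __new__
    def __init__(self, iterable, start=1, step=10):
        self.step = step


try:
    NaiveLogEnumerate("abc", step=2)
except TypeError as exc:
    # enumerate.__new__ rejects the unknown 'step' keyword
    print("NaiveLogEnumerate failed:", exc)

fixed = FixedLogEnumerate("abc", step=2)
print(fixed.step, list(fixed))  # 2 [(1, 'a'), (2, 'b'), (3, 'c')]
```

The exact TypeError message differs between Python versions, which is why the sketch only prints whatever the interpreter reports.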

So our implementation looks like this:

class LogEnumerate(enumerate):
    def __new__(cls, iterable, start=1, *args, **kwargs):
        # enumerate.__new__ accepts only (iterable, start); drop our extra options
        return super().__new__(cls, iterable, start)

    def __init__(self, iterable, start=1, step=10,
                 start_message='', progress_message='', stop_message=''):
        # Called with the same arguments as __new__; iterable and start are already used
        self.progress_message = progress_message
        self.stop_message = stop_message
        self.step = step
        self.total = len(iterable)  # iterable must support len(), e.g. a list
        print(start_message.format(self.total))

    def __next__(self):
        try:
            i, item = super().__next__()
        except StopIteration:
            print(self.stop_message)
            raise
        if not i % self.step:
            print(self.progress_message.format(i, self.total))
        # Return only the item: the counter is used just for progress logging
        return item
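To see the class at work, here is a sketch of the task from the beginning of the post rewritten on top of it. The class is repeated so the snippet runs on its own; get_bunch() and func() are hypothetical stand-ins (a list of 25 integers and a function that fails on item 7), not our real DB code.

```python
# Hypothetical stand-ins for the task-specific parts of the original loop
def get_bunch():
    return list(range(1, 26))


def func(item):
    if item == 7:
        raise ValueError("bad item {}".format(item))


class LogEnumerate(enumerate):
    # Same class as above, repeated so this sketch is self-contained
    def __new__(cls, iterable, start=1, *args, **kwargs):
        return super().__new__(cls, iterable, start)

    def __init__(self, iterable, start=1, step=10,
                 start_message='', progress_message='', stop_message=''):
        self.progress_message = progress_message
        self.stop_message = stop_message
        self.step = step
        self.total = len(iterable)
        print(start_message.format(self.total))

    def __next__(self):
        try:
            i, item = super().__next__()
        except StopIteration:
            print(self.stop_message)
            raise
        if not i % self.step:
            print(self.progress_message.format(i, self.total))
        return item


# The loop body now contains only the task-specific work and error handling
for item in LogEnumerate(get_bunch(), step=10,
                         start_message="total: {}",
                         progress_message="done {} of {}",
                         stop_message="Done!"):
    try:
        func(item)
    except Exception as e:
        print("catch exception: {}".format(e))
```

Running it prints "total: 25", then "catch exception: bad item 7", then "done 10 of 25" and "done 20 of 25", and finally "Done!". All the logging boilerplate lives in one place, and each task only supplies its messages and its func.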