A decorator that got class

What if I told you there is not just one type of decorator in Python, but two? First things first: what is a decorator? And what are those two types? Without further ado, let us dive into the world of decorators.

Decorator

A decorator in Python is nothing more than a callable that takes a function, adds functionality to it, and returns a new function in its place, so the same uniform transformation can be applied to many functions. For example, you might always want access to a certain argument, or want to propagate arguments across function calls. Decorators help with that.


import functools

def decorator(fn):
    @functools.wraps(fn)  # copy fn's __name__ and __doc__ onto the wrapper
    def wrapper(*args, **kwargs):
        print(f"Going to run function: {fn} with following arguments: {args} and {kwargs}")
        return fn(*args, **kwargs)
    return wrapper

@decorator
def add(a, b):
    return a + b

add(1, b=2)
add(2, 3)

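This is a simple example. You apply the decorator by writing @ followed by the name of a previously defined decorator function, directly above the function you want to wrap. Python calls decorator once, when add is defined, and passes the decorated function in as fn. It then binds the name add to the returned wrapper. Every later call to add runs wrapper, which prints its message and forwards the arguments to the original function.

If you want to configure the decorator (and so give it some state), add another level of nesting. The outermost function receives the decorator's own arguments and returns the actual decorator, like the following: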


def decorator(n):
    def params(fn):
        def wrapper(*args, **kwargs):
            print(f"Going to run function: {fn} with following arguments: {args} and {kwargs}")
            print(f"Value of n was {n}")
            return fn(*args, **kwargs)
        return wrapper
    return params

@decorator(10)
def add(a, b):
    return a + b

print(add(1, b=2))
print(add(2, 3))
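Because @ is only shorthand, both forms can be unrolled by hand. Writing @decorator is the same as add = decorator(add). Writing @decorator(10) is the same as add = decorator(10)(add).

The sketch below uses a hypothetical scale_by decorator, which is not part of the original example. It multiplies the result by n, so the effect of each level is easy to see.

```python
import functools

def scale_by(n):
    # Outer level: receives the decorator's own argument (n).
    def params(fn):
        # Middle level: receives the function being decorated, once.
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            # Inner level: runs on every call of the decorated function.
            return fn(*args, **kwargs) * n
        return wrapper
    return params

@scale_by(10)
def add(a, b):
    return a + b

def plain_add(a, b):
    return a + b

# Unrolled by hand: first build the decorator, then apply it to the function.
manual_add = scale_by(10)(plain_add)

print(add(1, 2))         # 30
print(manual_add(1, 2))  # 30
```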

Types of decorators

Function based

The first type you have already met: the function-based decorator. You define nested functions and make the modifications inside them.

The main problem with this approach is that any state lives in closures, which makes it hard to manage and very implicit. For example, when a function-based decorator wraps a method inside a class, self suddenly shows up as the first element of args in the wrapper, and it is not obvious where it came from.
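Here is a small sketch of what "state lives in a closure" means, using a hypothetical count_calls decorator. The counter is a local variable of the outer function, and the wrapper needs nonlocal to update it. There is also no obvious way to read the counter from outside.

```python
import functools

def count_calls(fn):
    calls = 0  # state captured in the closure, not visible as an attribute

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        nonlocal calls  # required to rebind the enclosing variable
        calls += 1
        print(f"{fn.__name__} has been called {calls} time(s)")
        return fn(*args, **kwargs)
    return wrapper

@count_calls
def greet(name):
    return f"Hello, {name}"

greet("Stealthy")
greet("Stealthy")
# There is no greet.calls; the counter is only reachable through
# greet.__closure__, which is an implementation detail.
```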

Class based

This is the quirkier one. By defining a __call__ method, you make the instances of a class callable:


class A:
    def __call__(self):
        print("Called")

a = A()
a()

Now you can call the instance (a) like a function, which opens the door to using it as a decorator. On top of that, we get full access to class-based features like static methods, class methods, inheritance, and explicit state stored as attributes.
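To make that concrete, here is a minimal sketch of a stateful class-based decorator, using a hypothetical CountCalls class. Python calls CountCalls(greet) at definition time, so __init__ receives the function. Every later call to greet goes through __call__. Unlike state captured in a closure, the counter is an ordinary attribute you can read from the outside.

```python
import functools

class CountCalls:
    def __init__(self, fn):
        # Copy __name__, __doc__, etc. from fn onto this instance
        functools.update_wrapper(self, fn)
        self.fn = fn
        self.calls = 0  # state is a plain, inspectable attribute

    def __call__(self, *args, **kwargs):
        self.calls += 1
        return self.fn(*args, **kwargs)

@CountCalls
def greet(name):
    return f"Hello, {name}"

greet("Stealthy")
greet("Stealthy")
print(greet.calls)  # 2
```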

Example

The following example demonstrates some of these ideas. First comes the full code listing, then a brief look at each part and what it does.


from abc import abstractmethod
from dataclasses import dataclass
from typing import Any, Callable 
import re


@dataclass
class Fixture:
    name: str
    value: Any


user_fixture = Fixture("Person", {"name": "Stealthy", "email": "stealthy@stealthy.com"})
message_fixture = Fixture("Message", {"data": {"name": "Stealthy"}, "msg": "Account created", "code": 201})

class WithFixtures:
    def __init__(self, *fixtures: Fixture):
        self.fixtures = fixtures
        self.init = False

    def __call__(self, *args, **kwargs):
        if not self.init:
            self.test_func = args[0]
            self.init = True
            return self
        return self.re_entrant(*args, **kwargs)

    def re_entrant(self, *args, **kwargs):
        if isinstance(self.test_instance, (UserTest,)):
            self.test_instance.validationTest(self.fixtures[0][0].value)
            return self.test_func(self.test_instance, *self.fixtures, *args, **kwargs)
        elif hasattr(self.test_instance, "__enter__"):
            # The fixtures are passed to the decorator as a single list
            self.test_instance.fixtures = list(self.fixtures[0])
            with self.test_instance as instance:
                return self.test_func(instance, *self.fixtures, *args, **kwargs)
        else:
            return self.test_func(self.test_instance, *self.fixtures, *args, **kwargs)

class WithInjectedMethod:
    def __init__(self, method_name: str, func: Callable):
        self.method_name = method_name
        self.injected_method = func
        self.init = False

    def __call__(self, *args, **kwargs):
        if not self.init:
            self.test_func = args[0]
            self.init = True
            return self
        return self.re_entrant(*args, **kwargs)

    def re_entrant(self, *args, **kwargs):
        setattr(self.test_instance, self.method_name, self.injected_method)
        return self.test_func(self.test_instance, *args, **kwargs)

class Test:
    def __new__(cls):
        TRAIT_NAMES = [x.__name__ for x in [WithFixtures, WithInjectedMethod]]
        instance = super().__new__(cls)
        # Give each WithFixtures / WithInjectedMethod object on the class a reference to this instance
        for prop_def in instance.__class__.__dict__.values():
            if callable(prop_def) and any(name in str(prop_def) for name in TRAIT_NAMES):
                prop_def.test_instance = instance

        return instance

    @abstractmethod
    def setUp(self) -> None:
        raise NotImplementedError

    @abstractmethod
    def tearDown(self) -> None:
        raise NotImplementedError
    

class UserTest(Test):
    @WithInjectedMethod("log", lambda x: print(x))
    def validationTest(self, user_fixture):
        if "name" not in user_fixture:
            raise ValueError("name missing from fixture")
        if "email" not in user_fixture:
            raise ValueError("email missing from fixture")
        self.log("Verification successful")

    @WithFixtures([user_fixture])
    def verify_name(self, *args, **kwargs):
        name_regex = re.compile(r"^[a-z].*$", re.I)
        if not name_regex.match(args[0][0].value['name']):
            raise Exception("verify_name failed")
        self.log("Name verified")

    @WithFixtures([user_fixture, message_fixture])
    def verify_email(self, *args, **kwargs):
        email_regex = re.compile(r"^[a-z].*@[a-z].*\.[a-z]\.?[a-z]*$", re.I)
        if not email_regex.match(args[0][0].value['email']):
            raise Exception("verify_email failed")

class MessageTest(Test):
    
    def setUp(self):
        for fixture in self.fixtures:
            if "data" not in fixture.value:
                raise ValueError(f"Fixture is incorrect: {fixture.value}")
        # Open db connection 


    def tearDown(self):
        # Close db connection
        ...

    def __enter__(self):
        self.setUp()
        return self  # bound to the name after "as" in the with statement

    def __exit__(self, exc_type, exc_value, exc_traceback):
        self.tearDown()

    @WithFixtures([message_fixture])
    def verify_message(self, *args, **kwargs):
        if args[0][0].value['msg'] != "Account created":
            raise Exception("wrong message value")

t = UserTest()
t.verify_name()
t.verify_email()

t = MessageTest()
t.verify_message()

I chose a simple example: a small test harness / framework written with this technique.

Imports and fixtures


from abc import abstractmethod
from dataclasses import dataclass
from typing import Any, Callable 
import re


@dataclass
class Fixture:
    name: str
    value: Any


user_fixture = Fixture("Person", {"name": "Stealthy", "email": "stealthy@stealthy.com"})
message_fixture = Fixture("Message", {"data": {"name": "Stealthy"}, "msg": "Account created", "code": 201})

First come the imports we need. Then comes Fixture, a dataclass that just holds two properties (a name and a value) so data can be shipped around. Finally there are some predefined fixtures with static data to operate on. They resemble the JSON messages you might exchange with a third-party REST API or with your own backend.
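Since the fixture values mirror JSON messages, a Fixture can just as well be built from a raw JSON string. This is a minimal sketch, assuming a hypothetical payload string shaped like message_fixture. The call json.loads turns it into the same nested dict.

```python
import json
from dataclasses import dataclass
from typing import Any

@dataclass
class Fixture:
    name: str
    value: Any

# Hypothetical payload, shaped like message_fixture in the listing.
payload = '{"data": {"name": "Stealthy"}, "msg": "Account created", "code": 201}'
message_fixture = Fixture("Message", json.loads(payload))

print(message_fixture.value["code"])  # 201
```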

WithFixtures

class WithFixtures:
    def __init__(self, *fixtures: Fixture):
        self.fixtures = fixtures
        self.init = False

    def __call__(self, *args, **kwargs):
        if not self.init:
            self.test_func = args[0]
            self.init = True
            return self
        return self.re_entrant(*args, **kwargs)

    def re_entrant(self, *args, **kwargs):
        if isinstance(self.test_instance, (UserTest,)):
            self.test_instance.validationTest(self.fixtures[0][0].value)
            return self.test_func(self.test_instance, *self.fixtures, *args, **kwargs)
        elif hasattr(self.test_instance, "__enter__"):
            # The fixtures are passed to the decorator as a single list
            self.test_instance.fixtures = list(self.fixtures[0])
            with self.test_instance as instance:
                return self.test_func(instance, *self.fixtures, *args, **kwargs)
        else:
            return self.test_func(self.test_instance, *self.fixtures, *args, **kwargs)

This is the first actual decorator. The __init__ method is an ordinary constructor that stores the fixtures passed to the decorator. The magic happens in __call__.

Python calls the decorator object once at definition time, with the decorated function as its only argument. The init switch uses that first call to store the target function, prime the object, and return the object itself, which then replaces the method on the class. Every later call goes straight to re_entrant. That method contains the reusable logic that transforms the function or adds functionality to it.

In this case, if the test instance is a UserTest, it always runs the defined validationTest method first, with the value of the first fixture, and then the original target.

If the instance has an __enter__ method, it is treated as a context manager. The fixtures are attached to the test_instance as its fixtures attribute, and the target runs inside a with statement, so setUp and tearDown wrap the test.

Otherwise it simply runs the target function, injecting the fixtures as positional arguments.
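The prime-then-re-enter switch can be seen on its own, without the test framework around it. In this minimal sketch, PrimeThenRun is a hypothetical class: its arguments stand in for fixtures, and they are prepended to every later call.

```python
class PrimeThenRun:
    """Hypothetical stripped-down version of the init switch in WithFixtures."""

    def __init__(self, *extra):
        self.extra = extra  # decorator arguments, e.g. fixtures
        self.init = False

    def __call__(self, *args, **kwargs):
        if not self.init:
            # First call: Python hands us the decorated function.
            self.func = args[0]
            self.init = True
            return self  # the decorated name now refers to this object
        # Every later call: forward to the function, prepending the extras.
        return self.func(*self.extra, *args, **kwargs)

@PrimeThenRun("fixture-a", "fixture-b")
def show(*args, **kwargs):
    return args, kwargs

print(show(1, key="v"))  # (('fixture-a', 'fixture-b', 1), {'key': 'v'})
```

A more common layout takes the function in __call__ and returns a separate wrapper function. The switch used here instead keeps the decorator object itself as the class attribute, which is what later lets Test.__new__ find it.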

WithInjectedMethod

class WithInjectedMethod:
    def __init__(self, method_name: str, func: Callable):
        self.method_name = method_name
        self.injected_method = func
        self.init = False

    def __call__(self, *args, **kwargs):
        if not self.init:
            self.test_func = args[0]
            self.init = True
            return self
        return self.re_entrant(*args, **kwargs)

    def re_entrant(self, *args, **kwargs):
        setattr(self.test_instance, self.method_name, self.injected_method)
        return self.test_func(self.test_instance, *args, **kwargs)

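This is another example of what we can do. This decorator only attaches a predefined function (here a lambda) to the target instance under the given name, and then runs the target method. Because the function is stored on the instance rather than on the class, it is not bound and does not receive self. That is why the log lambda takes a single argument.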
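The difference between attaching a plain function to an instance and attaching a bound method can be seen in isolation. In this sketch (Target and describe are hypothetical names), the function stored with setattr is called without self. By contrast, types.MethodType binds a function to the instance explicitly, which you would need if the injected method has to access the test instance.

```python
import types

class Target:
    pass

t = Target()

# A plain function stored on the instance is NOT bound: it never receives self.
setattr(t, "log", lambda msg: f"log: {msg}")
print(t.log("hi"))  # log: hi

# To inject something that does receive self, bind it explicitly.
def describe(self):
    return f"I am a {type(self).__name__}"

setattr(t, "describe", types.MethodType(describe, t))
print(t.describe())  # I am a Target
```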

Base Test class

class Test:
    def __new__(cls):
        TRAIT_NAMES = [x.__name__ for x in [WithFixtures, WithInjectedMethod]]
        instance = super().__new__(cls)
        # Give each WithFixtures / WithInjectedMethod object on the class a reference to this instance
        for prop_def in instance.__class__.__dict__.values():
            if callable(prop_def) and any(name in str(prop_def) for name in TRAIT_NAMES):
                prop_def.test_instance = instance

        return instance

    @abstractmethod
    def setUp(self) -> None:
        raise NotImplementedError

    @abstractmethod
    def tearDown(self) -> None:
        raise NotImplementedError

This showcases the __new__ method, which is not shown nearly as often. It plays a role in how a class gets instantiated. Calling Test() first runs __new__ to create the instance and then, if __new__ returned an instance of the class, runs __init__ on it. Note that __new__ does not call __init__ itself: Python's instantiation machinery (type.__call__) calls both in turn.

This is exactly what I want: I need an instance, and then I modify it based on whether any of the attributes (methods) defined on the class is one of the decorators named in TRAIT_NAMES. The loop first keeps only callables, then checks whether the string representation of each one mentions WithFixtures or WithInjectedMethod. If so, it sets the test_instance attribute on that decorator object. Matching on the name is fragile; isinstance(prop_def, (WithFixtures, WithInjectedMethod)) would be a sturdier check.
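The order in which Python runs __new__ and __init__ can be traced with a hypothetical Traced class that records each call. The second, also hypothetical, class shows the rule that __init__ only runs when __new__ returns an instance of the class.

```python
calls = []

class Traced:
    def __new__(cls, *args, **kwargs):
        calls.append("__new__")
        # object.__new__ only needs the class; the arguments go to __init__
        return super().__new__(cls)

    def __init__(self, value):
        calls.append("__init__")
        self.value = value

obj = Traced(42)
print(calls)  # ['__new__', '__init__']

class Skipped:
    def __new__(cls):
        return "not an instance"  # not a Skipped, so __init__ is skipped

    def __init__(self):
        calls.append("never runs")

print(Skipped())  # not an instance
print(calls)      # ['__new__', '__init__']
```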

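The caveat here is that this setup does not allow stacking multiple decorators on one method. There is a second caveat: each decorator object is a class attribute shared by every instance, so test_instance always points to the most recently created instance of that class.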
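If one decorator object needs to serve several instances, the usual Python mechanism is the descriptor protocol. This is a different technique from the __new__ injection used in this article. A __get__ method is called whenever the attribute is looked up on an instance, and it can return a method bound to that particular instance.

Here it is sketched with a hypothetical WithLogging decorator. Nothing is stored on the shared decorator object, so a second instance does not overwrite the first.

```python
import functools
import types

class WithLogging:
    """Hypothetical class-based decorator that binds per instance via __get__."""

    def __init__(self, func):
        functools.update_wrapper(self, func)
        self.func = func

    def __get__(self, instance, owner=None):
        if instance is None:
            return self  # accessed on the class itself
        # Bind this callable object to the instance, like a normal method.
        return types.MethodType(self, instance)

    def __call__(self, instance, *args, **kwargs):
        print(f"running {self.func.__name__} on {type(instance).__name__}")
        return self.func(instance, *args, **kwargs)

class Example:
    def __init__(self, name):
        self.name = name

    @WithLogging
    def hello(self):
        return f"hello from {self.name}"

a, b = Example("a"), Example("b")
print(a.hello())  # hello from a
print(b.hello())  # hello from b
```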

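Then there are the setUp and tearDown hooks. They are marked with @abstractmethod, but Test does not inherit from abc.ABC (or use ABCMeta), so nothing enforces them. They are effectively optional: UserTest leaves them out, while MessageTest implements them.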
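Whether @abstractmethod is enforced depends on the metaclass. In this sketch (all class names are hypothetical), the marker on a plain class is ignored. A subclass of abc.ABC, however, refuses to be instantiated until the method is implemented.

```python
from abc import ABC, abstractmethod

class Plain:
    @abstractmethod
    def setUp(self) -> None: ...

class Enforced(ABC):
    @abstractmethod
    def setUp(self) -> None: ...

class Implemented(Enforced):
    def setUp(self) -> None:
        pass

Plain()        # works: without ABCMeta the marker is ignored
Implemented()  # works: the abstract method has been overridden

try:
    Enforced()  # raises TypeError: setUp is still abstract
except TypeError as exc:
    print(f"TypeError: {exc}")
```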

UserTest class

class UserTest(Test):
    @WithInjectedMethod("log", lambda x: print(x))
    def validationTest(self, user_fixture):
        if "name" not in user_fixture:
            raise ValueError("name missing from fixture")
        if "email" not in user_fixture:
            raise ValueError("email missing from fixture")
        self.log("Verification successful")

    @WithFixtures([user_fixture])
    def verify_name(self, *args, **kwargs):
        name_regex = re.compile(r"^[a-z].*$", re.I)
        if not name_regex.match(args[0][0].value['name']):
            raise Exception("verify_name failed")
        self.log("Name verified")

    @WithFixtures([user_fixture, message_fixture])
    def verify_email(self, *args, **kwargs):
        email_regex = re.compile(r"^[a-z].*@[a-z].*\.[a-z]\.?[a-z]*$", re.I)
        if not email_regex.match(args[0][0].value['email']):
            raise Exception("verify_email failed")

This is one of the actual Test subclasses. It shows how to apply the different decorators, passing either a single fixture or several fixtures inside a list. As you can see, the log method persists across calls. validationTest injects it the first time it runs, and because it is stored on the instance itself, verify_name can use it afterwards.

MessageTest class

class MessageTest(Test):
    
    def setUp(self):
        for fixture in self.fixtures:
            if "data" not in fixture.value:
                raise ValueError(f"Fixture is incorrect: {fixture.value}")
        # Open db connection 


    def tearDown(self):
        # Close db connection
        ...

    def __enter__(self):
        self.setUp()
        return self  # bound to the name after "as" in the with statement

    def __exit__(self, exc_type, exc_value, exc_traceback):
        self.tearDown()

    @WithFixtures([message_fixture])
    def verify_message(self, *args, **kwargs):
        if args[0][0].value['msg'] != "Account created":
            raise Exception("wrong message value")

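This class shows how to turn a test into a context manager, so that setUp and tearDown run around each decorated test method. It also shows how attributes such as fixtures can be added to instances at run time.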
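The value bound after as in a with statement is whatever __enter__ returns, not the context manager itself. That is why __enter__ in a class like MessageTest should end with return self; otherwise the code inside the with block receives None. The two hypothetical classes below show the difference.

```python
class NoReturn:
    def __enter__(self):
        pass  # implicitly returns None

    def __exit__(self, exc_type, exc_value, exc_traceback):
        return False  # do not suppress exceptions

class ReturnsSelf:
    def __enter__(self):
        return self  # this is the value bound by "as"

    def __exit__(self, exc_type, exc_value, exc_traceback):
        return False

with NoReturn() as a:
    print(a)  # None

manager = ReturnsSelf()
with manager as b:
    print(b is manager)  # True
```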

Running the tests


t = UserTest()
t.verify_name()
t.verify_email()

t = MessageTest()
t.verify_message()

This just creates the instances of the classes and runs the methods.

Conclusion

I hope this showcases the idea of using classes as the template for your decorators and how to wire them up to instances dynamically. By extension, it also shows how to add behaviour to class instances programmatically after instantiating them.

This can be improved further with type hints, to make it even clearer what kind of instances you are getting and what you are allowed to do with them. You could even write a typing.Protocol to describe the implicit interface (such as setUp and tearDown) and keep things even more generic.
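Here is a minimal sketch of such a Protocol, assuming the interface is just the two hooks. SupportsSetUpTearDown, DbTest, NoHooks, and run_with_hooks are hypothetical names.

With @runtime_checkable, isinstance checks structurally that the methods exist, though not their signatures. A check like this could stand in for the hasattr check on __enter__ in WithFixtures.

```python
from typing import Protocol, runtime_checkable

@runtime_checkable
class SupportsSetUpTearDown(Protocol):
    """Hypothetical interface for tests that want hooks around each run."""

    def setUp(self) -> None: ...
    def tearDown(self) -> None: ...

class DbTest:
    def setUp(self) -> None:
        print("open db connection")

    def tearDown(self) -> None:
        print("close db connection")

class NoHooks:
    pass

def run_with_hooks(test: SupportsSetUpTearDown) -> None:
    test.setUp()
    try:
        ...  # the actual test body would go here
    finally:
        test.tearDown()

run_with_hooks(DbTest())
print(isinstance(DbTest(), SupportsSetUpTearDown))   # True
print(isinstance(NoHooks(), SupportsSetUpTearDown))  # False
```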