デコレータでPyTorchのTensorと次元を監視する

はじめに

 PyTorchを使っていると入力が3次元固定であることがよくあります。そのため自作する関数もそれに合わせたりするのですが、使う側としては入力がTensorや3次元ではない時に教えてくれたほうが親切なわけです。しかしながら、作成済み関数に手を加えるのは気が引けるのと、関数が1つや2つならまだしも全ての関数にこの機能を追加するのは大変です。今回はそんなワガママに応えてくれるデコレータの話になります。

論よりコード

import functools  
import torch  


def _is_tensor3d(x):  
    return torch.is_tensor(x) and x.ndimension() == 3  


def is_tensor3d(func):  
    @functools.wraps(func)  
    def wrapper(x, *args, **kwargs):  
        if not _is_tensor3d(x):  
            raise TypeError('x is not a 3D tensor')  
        return func(x, *args, **kwargs)  
    return wrapper  


@is_tensor3d  
def zscore(x, axis):  
    x_mean = x.mean(dim=axis, keepdim=True)  
    x_std = x.std(dim=axis, keepdim=True)  
    return (x - x_mean) / x_std  

監視用の関数

まず、デコレータ内で使う監視用の関数を作ります。

import torch  

def _is_tensor3d(x):  
    return torch.is_tensor(x) and x.ndimension() == 3  

これはTensorかつ3次元ならTrueが返ってくる関数です。

デコレータ

次にデコレータ を作成します。

import functools  

def is_tensor3d(func):  
    @functools.wraps(func)  
    def wrapper(x, *args, **kwargs):  
        if not _is_tensor3d(x):  
            raise TypeError("x is not a 3D tensor")  
        return func(x, *args, **kwargs)  
    return wrapper  

入力xが3次元のTensorではない時はエラーを出すようにしておきます。

関数内部にあるfunctools.wrapsですが、これを使うことで装飾した関数の引数やdocstringが隠れてしまうことを防げます。
例えば、Jupyter Notebookだとfunctools.wrapsを使わない場合、

関数の引数が(x, *args, **kwargs)と表示され、使う側はaxisがあることがわかりません。

functools.wrapsを装飾することで、引数(x, axis)と正しく表示されるようになります。

装飾したい関数

@is_tensor3d  
def zscore(x, axis):  
    x_mean = x.mean(dim=axis, keepdim=True)  
    x_std = x.std(dim=axis, keepdim=True)  
    return (x - x_mean) / x_std  

軸方向で標準化する関数にis_tensor3dを装飾しておきます。先ほどのJupyter Notebookの例はこの関数での話でした。

実験

正しく動くかどうか試してみましょう。

import numpy as np  

a = torch.rand((15, 5, 20))  # 3次元のTensor  
print(zscore(a, (0, 1, 2)).std((0, 1, 2)))  # tensor(1.)  

b = torch.rand((10, 20))  # 2次元のTensor  
print(zscore(b, 1))  # TypeError: x is not a torch 3D.  

c = np.random.rand(15, 5, 20)  # 3次元のndarray  
print(zscore(c, 0))  # TypeError: x is not a torch 3D.  

しっかりと3次元Tensorのみを通していることが確認できました。

注意点

デコレータのwrapperと装飾した関数の引数を合わせないと問題が起こります。試しにzscoreの引数をxからyに変えたzscore2で実験してみます。

@is_tensor3d  
def zscore2(y, axis):  
    y_mean = y.mean(dim=axis, keepdim=True)  
    y_std = y.std(dim=axis, keepdim=True)  
    return (y - y_mean) / y_std  

d = torch.rand((15, 5, 20))  
print(zscore2(d, axis=(0, 1, 2)))  # 実行できる  

print(zscore2(y=d, axis=(0, 1, 2)))  # キーワード引数yだと実行できない  
# TypeError: wrapper() missing 1 required positional argument: 'x'  

print(zscore2(x=d, axis=(0, 1, 2)))  # キーワード引数xだと実行できる  

zscore2の引数はyですが、デコレータの方はxなのでyに代入しても実行できません。使う側は装飾してるかは知らないため混乱の元になります。なので引数名は統一しましょう。

最後に

 今回はPyTorchのTensorとその次元を監視しましたが、ndarrayやlist、intなどにでも転用できそうです。また、入力時にndarrayからTensorに変えたり、何かを出力時にTensorに戻したりと、よく使う機能を元の関数を変えずに追加できそうですね。便利!

参考

Pythonのデコレータにはwrapsをつけるべきという覚え書き - Qiita