CCCMKホールディングス TECH LABの Tech Blog

TECH LABのエンジニアが技術情報を発信しています

ブログタイトル

scikit-learn準拠モデル作成に挑戦

はじめに

はじめまして、技術開発ユニットの伊藤です。 4月に入社した新卒1年目エンジニア?です。

4月中は木を作成するのがマイブームで、 趣味で決定木やExtra Treeを実装していました。 自作Extra Treeをラップし、 scikit-learn準拠モデルにした際にハマった点について書きます。 (Extra Treeはscikit-learnに実装されているので自作する意味はほぼ無いです。)

scikit-learn準拠モデルの作成にあたり、以下の記事を参考にしました。

sklearn準拠モデルの作り方

環境

使用したpython, ライブラリのバージョンは以下のとおりです。

  • Python 3.8.2
  • numpy 1.18.1
  • sklearn 0.22.1

ラップ前の自作モデル

自作Extra TreeのクラスNodeは 木を作成するメソッドbuild()、 予測(平均値)を出力するメソッドpredict()を実装しました。 Nodeクラスは、分類ではbuild()に 目的変数のone-hot表現を入力する仕様です。 Nodeクラスはこのままではscikit-learn準拠ではないので、 ラップしてscikit-learn準拠にします。

scikit-learn準拠か確認

sklearn.utils.estimator_checks.check_estimator()で 自作推定器がscikit-learn準拠になっているか確認できるようです。 例えば、分類器ExtraTreeClassifierをテストしたいときは以下のように書きます。

check_estimator(ExtraTreeClassifier())

BaseEstimatorの実装

まず、分類、回帰で共通する部分を基底クラスExtraTreeで実装します。

class ExtraTree(BaseEstimator):
    def __init__(self, max_depth=10, random_state=None):
        self.max_depth = max_depth
        self.random_state = random_state

    def fit(self, X, y):
        if not hasattr(self, 'root_'):
            self.root_ = Node()
            self.n_features_ = X.shape[1]
        random_state = check_random_state(self.random_state)
        
        self.root_.build(X, y, self.max_depth, random_state)

        return self

    def predict(self, X):
        X = check_array(X)

        check_is_fitted(self, 'root_')

        if X.shape[1] != self.n_features_:
            raise ValueError()

        return self.root_.predict(X)
  1. メンバ変数の位置

    以下を行うとcheck_estimator()でエラーになりました。

    • サフィックスとして"_"をつけたメンバ変数を__init__()で定義する
    • サフィックスとして"_"をつけていないメンバ変数をfit()で定義する
  2. random_state

    sklearn.utils.check_random_state()を使用すると簡単にrandom_stateに対応できました。

  3. 入力データの確認

    sklearn.utils.check_array(), sklearn.utils.check_X_y()でlistの入力に対応できます。またnanを許容するかなどの設定も行えるようです。 ただしfit()で入力されたXとpredict()で入力されたXの列数が同じかは判定してくれないので自分で判定する必要がありました。

  4. その他

    fit()でreturn selfする必要があると思っていましたが、return selfをコメントアウトしてもcheck_estimator()でエラーになりませんでした。

Classifierの実装

基底クラスExtraTreeを継承し分類器を実装します。

class ExtraTreeClassifier(ExtraTree, ClassifierMixin):
    def fit(self, X, y):
        X, y = check_X_y(X, y)
        check_classification_targets(y)

        # one-hot encoding
        self.classes_, inverse = np.unique(y, return_inverse=True)
        y = np.identity(self.classes_.shape[0])[inverse]

        return super().fit(X, y)

    def predict_proba(self, X):
        return super().predict(X)

    def predict(self, X):
        proba = self.predict_proba(X)
        return self.classes_[proba.argmax(axis=1)]
  1. check_classification_targets()を使う

    check_estimators()では分類器に対し、連続量の目的変数を入力したときに適切に処理するかテストされます。sklearn.utils.multiclass.check_classification_targets()で連続量の目的変数に対し、ValueErrorしてくれます。

  2. 目的変数を変換する

    自作Extra Treeでは分類のとき目的変数をone-hot encodingするようにしています。check_estimators()で文字列のベクトルの入力に対処する必要がありました。np.unique()などを使用することで対処しました。

Regressorの実装

class ExtraTreeRegressor(MultiOutputMixin, ExtraTree, RegressorMixin):
    def fit(self, X, y):
        X, y = check_X_y(X, y, multi_output=True)

        if y.ndim < 2:
            y = y.reshape(-1, 1)

        return super().fit(X, y)

    def predict(self, X):
        y = super().predict(X)

        if y.shape[1] == 1:
            y = y.flatten()

        return y
  1. MultiOutput回帰に対応させる

    sklearn.base.MultiOutputMixinを継承、check_X_y(multi_output=True)とすることでMultiOutput回帰に対応させました。

おわりに

上記のような修正を行った結果、 無事にscikit-learn準拠モデルの作成ができました! これで自作推定器にscikit-learnのcross validationやbaggingを適用できると思います。