RuntimeError: CUDA error: Device-side assert triggered の解決方法

· 3min · Masataka Kashiwagi

Pytorch でモデルを作成していた際に,「RuntimeError: CUDA error: device-side assert triggered」が発生し,原因がよくわからなかったので,調べたことをメモしておく

エラー発生の原因

調べてみると,原因としては以下のようなものがある.

  • ライブラリの Version が違う
  • ラベル/クラスの数とネットワークの入出力の shape が異なる
  • Loss 関数の入力が正確でない

などなど...

よくあるのが,下2つかなと思う.

ラベル/クラスの数とネットワークの入出力の shape が異なる

想定しているラベルもしくはクラス数とネットワークの出力のクラス数が異なる場合,この場合は FC 層の最後に nn.Linear(input, num_class) を入れて調整する必要がある.

Loss 関数の入力が正確でない

僕が遭遇したのはこちらのパターンになる.

例えば,BCELoss を考えた場合,計算するためには値としては0~1を取る必要がある.そのため普通は最終出力に Sigmoid関数 or Softmax関数 を入れる.

それ以外にも Loss の設計で以下のようにしておくと良い.

class BCELoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.bce = nn.BCELoss()

    def forward(self, input, target):
        input = torch.where(torch.isnan(input), torch.zeros_like(input), input)
        input = torch.where(torch.isinf(input), torch.zeros_like(input), input)
        input = torch.where(input>1, torch.ones_like(input), input)  # 1を超える場合には1にする

        target = target.float()

        return self.bce(input, target)

他の解決方法

他にも調べていると解決方法として CUDA の設定を以下にすると良いなどもあったが,解決するかどうかはよくわからない.

CUDA_LAUNCH_BLOCKING=1

今回は,Pytorch でのモデル作成時に発生したエラーについて整理した,モデル作成時にはモデルの In/Out や Loss 関数の定義をきちんと理解し把握しておく必要があると改めて感じた.同様のエラーが起きた場合には,この辺りをまずは調べてみるのが良さそう.

参考


このエントリーをはてなブックマークに追加

ブログ記事を読んで頂き,ありがとうございます!もしこの記事が良かったり参考になったら,「Buy me a coffee」ボタンから☕一杯をサポートして頂けるとモチベーションが上がります!どうぞよろしくお願いします🤩