ふるお〜と!- FullAuto

AI・ロボットが普及しBI(ベーシックインカム)が早急に実現されることを願う元ニートのブログ

ふるお〜と!-FullAuto

【ふるお〜と!】PyTorch実装の流れ

PyTorch | Pytorch3Dで使われる関数

No PyTorch PyTorch3D
1 Setting
Define
torch.cuda.is_available()
torch.manual_seed(.)
torch.device(...)
torch.utils.data.DataLoader(..)
torch.nn.Module
torch.cuda.is_available()
torch.device(.)
torch.nn.Module
torch.nn.Parameter(...)
torch.sigmoid(...)
2 Train optimizer.step() torch.optim.Adam(...)
torch.randperm(...)
optimizer.step()
3 Test torch.no_grad() torch.no_grad()
4 Save torch.save(.) ?

Sample:Mnistのmain.py

f:id:nullpo24:20210911123651p:plain

4.Saveについて

torch.save(model.state_dict(), "mnist_cnn.pt")

拡張子の.ptってなんだろう?

pytorchの略ね。他にも.pthとか実装者によりけりだわ。

このmodel(Tensor)の作り方もさまざまね。 pytorch3Dでは

deform_verts = torch.full(verts_shape, 0.0, device=device, requires_grad=True)

class VolumeModel(torch.nn.Module):
...

torch.nn.Moduleを継承して

volume_model = VolumeModel(
    renderer,
    volume_size=[volume_size] * 3, 
    voxel_size = volume_extent_world / volume_size,
).to(device)

のようにしてモデルが作られているわ。

ゆえに Tensorを保存するには以下のようにすれば良い

torch.save(volume_model.state_dict(), "banana.pt")

PyTorch実装の流れ

No.1のSettingとDefineってちょっとざっくりしすぎじゃないですか?

<PyTorch実装の流れ>
1.前処理、後処理、そしてネットワークモデルの入出力を確認
2.Datasetの作成
3.DataLoaderの作成
4.ネットワークモデルの作成
5.順伝搬(forward)の定義
6.損失関数の定義
7.最適化手法の設定
8.学習・検証の実施
9.テストデータで推論
※1

上記がPyTorch実装の流れね。 No.1の設定と定義、すなわち上の1~7は天才が設計するとして、あなた達のような凡人はデータ収集(annotator)と8.学習(training)・9.推論(inference)に注力したらいかがかしら?

データをぶち込むだけで誰でも扱えるようなサービスがそのうちできるようになることが予測されているわ

というわりにはまだ出てこないような・・・ とりあえず、現段階では学習済みモデルを使用して推論をするコードの作り方くらいは知っておく必要がありそうですね

参考

※1
つくりながら学ぶ!PyTorchによる発展ディープラーニング - 小川雄太郎 - Google ブックス