PyTorch | Pytorch3Dで使われる関数
No | PyTorch | PyTorch3D | ||
---|---|---|---|---|
1 | Setting Define |
torch.cuda.is_available() torch.manual_seed(.) torch.device(...) torch.utils.data.DataLoader(..) torch.nn.Module |
torch.cuda.is_available() torch.device(.) torch.nn.Module torch.nn.Parameter(...) torch.sigmoid(...) |
|
2 | Train | optimizer.step() | torch.optim.Adam(...) torch.randperm(...) optimizer.step() |
|
3 | Test | torch.no_grad() | torch.no_grad() | |
4 | Save | torch.save(.) | ? |
Sample:Mnistのmain.py
4.Saveについて
torch.save(model.state_dict(), "mnist_cnn.pt")
拡張子の.ptってなんだろう?
pytorchの略ね。他にも.pthとか実装者によりけりだわ。
このmodel(Tensor)の作り方もさまざまね。 pytorch3Dでは
deform_verts = torch.full(verts_shape, 0.0, device=device, requires_grad=True)
や
class VolumeModel(torch.nn.Module): ...
torch.nn.Moduleを継承して
volume_model = VolumeModel( renderer, volume_size=[volume_size] * 3, voxel_size = volume_extent_world / volume_size, ).to(device)
のようにしてモデルが作られているわ。
ゆえに Tensorを保存するには以下のようにすれば良い
torch.save(volume_model.state_dict(), "banana.pt")
PyTorch実装の流れ
No.1のSettingとDefineってちょっとざっくりしすぎじゃないですか?
<PyTorch実装の流れ>
1.前処理、後処理、そしてネットワークモデルの入出力を確認
2.Datasetの作成
3.DataLoaderの作成
4.ネットワークモデルの作成
5.順伝搬(forward)の定義
6.損失関数の定義
7.最適化手法の設定
8.学習・検証の実施
9.テストデータで推論
※1
上記がPyTorch実装の流れね。 No.1の設定と定義、すなわち上の1~7は天才が設計するとして、あなた達のような凡人はデータ収集(annotator)と8.学習(training)・9.推論(inference)に注力したらいかがかしら?
データをぶち込むだけで誰でも扱えるようなサービスがそのうちできるようになることが予測されているわ
というわりにはまだ出てこないような・・・ とりあえず、現段階では学習済みモデルを使用して推論をするコードの作り方くらいは知っておく必要がありそうですね