使用したGPU | 学習に要した時間 |
---|---|
GeForce RTX2080 | 28時間56分24秒(104184秒) |
import datetime import numpy as np import os import sys import cv2 from PIL import Image from tqdm import tqdm import glob import torch import torch.nn as nn import torch.nn.functional as F from torchvision import transforms from torch.utils.data import Dataset, DataLoader # Pix2PixHDのコードのものを使う from options.train_options import TrainOptions from models.models import create_model # proxy os.environ["http_proxy"] = "http://ccproxyz.kanagawa-it.ac.jp:10080" os.environ["https_proxy"] = "http://ccproxyz.kanagawa-it.ac.jp:10080" # GPUを使うかどうか USE_DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu' device = torch.device(USE_DEVICE) # PyTorchの内部を決定論的に設定する torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False # 乱数を初期化する np.random.seed(0) torch.manual_seed(0) if not torch.cuda.is_available(): # Pix2PixHDのオプションを上書き sys.argv.append('--gpu_ids') sys.argv.append('-1') # 画像のサイズ IMG_SIZE = 64 class MyTransform: def __init__(self): # イラストのtransform self.trans = transforms.Compose([ transforms.RandomApply([ transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5), ], 0.5), transforms.RandomGrayscale(0.1), transforms.ToTensor(), ]) def __call__(self, img): return self.trans(img) class MyDataset(Dataset): def __init__(self, transform): self.transform = transform # ファイルの一覧 self.filelist = [f for f in os.listdir('cropped/') if os.path.getsize('cropped/'+f) > 0] def __getitem__(self, idx): # イラスト画像を読み込む img = Image.open('cropped/' + self.filelist[idx]) img = img.resize((IMG_SIZE, IMG_SIZE)) # 線画を作成する threshold = np.random.randint(300)+400 # 閾値をランダムに選択する glimg = img.convert('L') # 白黒画像へ cvimg = np.array(glimg, dtype=np.uint8) # OpenCVの画像へ cvimg = cv2.Canny(cvimg, threshold, threshold) # 線画へ cvimg = (cvimg > 0).astype(np.int32) # 0か1かにする cvimg = cvimg.reshape((1, IMG_SIZE, IMG_SIZE)) # 線画とイラストのペアを返す return torch.tensor(cvimg), self.transform(img) def __len__(self): return len(self.filelist) # Pix2PixHDのデフォルトオプションを上書きする opt = TrainOptions().parse() opt.name = 'line2illust' # 保存するモデルのディレクトリ名 opt.label_nc = 2 # 入力は2値のラベル opt.no_instance = True # featmapは使わない # モデルの保存場所を作る if not os.path.isdir('checkpoints/line2illust'): os.makedirs('checkpoints/line2illust') # Pix2PixHDのモデルを作る model = create_model(opt) model = model.to(device) # GPUを使うときはGPUメモリ上に乗せる model.train() # モデルを学習用に設定する # 学習時のバッチサイズ BATCH_SIZE = 4 # 大きい方が良いがGPUメモリに依存 # 学習エポック数 NUM_EPOCHS = 50 # データセットの読み込みクラス dataset = MyDataset(MyTransform()) # 別スレッドでデータを読み込む data_loader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True) # OptimizerはPix2PixHDが用意してくれるものを使う optimizer_G, optimizer_D = model.module.optimizer_G, model.module.optimizer_D # 開始時刻を記録 start_t = datetime.datetime.now() # 学習ループ for epoch in range(NUM_EPOCHS): total_loss = [] # 各バッチ実行時の損失値 for X, y in tqdm(data_loader): # 画像を読み込んでtensorにする X = X.to(device) # GPUを使うときはGPUメモリ上に乗せる y = y.to(device) # GPUを使うときはGPUメモリ上に乗せる # モデルを実行する losses, generated = model(X, 0, y, 0) # GANの損失を求める losses = [ torch.mean(x) if not isinstance(x, int) else x for x in losses ] loss_dict = dict(zip(model.module.loss_names, losses)) loss_D = (loss_dict['D_fake'] + loss_dict['D_real']) * 0.5 loss_G = loss_dict['G_GAN'] + loss_dict.get('G_GAN_Feat',0) + loss_dict.get('G_VGG',0) # G-Netを逆伝播する optimizer_G.zero_grad() loss_G.backward() optimizer_G.step() # D-Netを逆伝播する optimizer_D.zero_grad() loss_D.backward() optimizer_D.step() # 全ての損失の平均値を求める loss = sum(losses) / len(losses) # 損失値を保存しておく total_loss.append(loss.detach().cpu().numpy()) # エポック終了時のスコアを求める total_loss = np.mean(total_loss) # 各バッチの損失の平均 # エポック終了時のスコアを表示する print(f'epoch #{epoch}: train_loss:{total_loss}') # 最終的なモデルを保存する model.module.save('chapt07-model.pth') # 終了時刻を記録 end_t = datetime.datetime.now() print(f'開始:{start_t}') print(f'終了:{end_t}')
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15892/15892 [35:20<00:00, 7.49it/s] epoch #0: train_loss:4.275062084197998 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15892/15892 [34:54<00:00, 7.59it/s] epoch #1: train_loss:4.163199424743652 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15892/15892 [34:50<00:00, 7.60it/s] epoch #2: train_loss:4.101843357086182 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15892/15892 [35:04<00:00, 7.55it/s] epoch #3: train_loss:4.066186428070068 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15892/15892 [35:05<00:00, 7.55it/s] epoch #4: train_loss:4.039588451385498 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15892/15892 [34:56<00:00, 7.58it/s] epoch #5: train_loss:4.015763759613037 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15892/15892 [34:44<00:00, 7.62it/s] epoch #6: train_loss:3.9875266551971436 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15892/15892 [34:20<00:00, 7.71it/s] epoch #7: train_loss:3.9640920162200928 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15892/15892 [34:38<00:00, 7.64it/s] epoch #8: train_loss:3.9381563663482666 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15892/15892 [34:32<00:00, 7.67it/s] epoch #9: train_loss:3.9110496044158936 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15892/15892 [34:36<00:00, 7.65it/s] epoch #10: train_loss:3.8827059268951416 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15892/15892 [34:42<00:00, 7.63it/s] epoch #11: train_loss:3.855252981185913 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15892/15892 [34:46<00:00, 7.62it/s] epoch #12: train_loss:3.8250529766082764 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15892/15892 [34:49<00:00, 7.61it/s] epoch #13: train_loss:3.8032658100128174 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15892/15892 [34:50<00:00, 7.60it/s] epoch #14: train_loss:3.774317502975464 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15892/15892 [34:51<00:00, 7.60it/s] epoch #15: train_loss:3.752943754196167 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15892/15892 [34:50<00:00, 7.60it/s] epoch #16: train_loss:3.731278419494629 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15892/15892 [34:46<00:00, 7.62it/s] epoch #17: train_loss:3.713496446609497 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15892/15892 [34:44<00:00, 7.62it/s] epoch #18: train_loss:3.6962809562683105 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15892/15892 [34:45<00:00, 7.62it/s] epoch #19: train_loss:3.677072763442993 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15892/15892 [34:30<00:00, 7.67it/s] epoch #20: train_loss:3.6649041175842285 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15892/15892 [34:41<00:00, 7.63it/s] epoch #21: train_loss:3.6497228145599365 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15892/15892 [34:29<00:00, 7.68it/s] epoch #22: train_loss:3.636552095413208 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15892/15892 [34:39<00:00, 7.64it/s] epoch #23: train_loss:3.618948221206665 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15892/15892 [34:32<00:00, 7.67it/s] epoch #24: train_loss:3.6085357666015625 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15892/15892 [34:31<00:00, 7.67it/s] epoch #25: train_loss:3.594952344894409 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15892/15892 [34:31<00:00, 7.67it/s] epoch #26: train_loss:3.58257794380188 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15892/15892 [34:44<00:00, 7.62it/s] epoch #27: train_loss:3.5724940299987793 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15892/15892 [34:33<00:00, 7.67it/s] epoch #28: train_loss:3.5590932369232178 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15892/15892 [34:40<00:00, 7.64it/s] epoch #29: train_loss:3.5533223152160645 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15892/15892 [34:32<00:00, 7.67it/s] epoch #30: train_loss:3.5394208431243896 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15892/15892 [35:02<00:00, 7.56it/s] epoch #31: train_loss:3.5275118350982666 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15892/15892 [35:02<00:00, 7.56it/s] epoch #32: train_loss:3.519040822982788 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15892/15892 [35:03<00:00, 7.56it/s] epoch #33: train_loss:3.511171579360962 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15892/15892 [34:56<00:00, 7.58it/s] epoch #34: train_loss:3.499484062194824 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15892/15892 [34:56<00:00, 7.58it/s] epoch #35: train_loss:3.4910361766815186 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15892/15892 [34:57<00:00, 7.58it/s] epoch #36: train_loss:3.4862022399902344 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15892/15892 [34:54<00:00, 7.59it/s] epoch #37: train_loss:3.4785609245300293 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15892/15892 [34:46<00:00, 7.62it/s] epoch #38: train_loss:3.465545177459717 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15892/15892 [34:31<00:00, 7.67it/s] epoch #39: train_loss:3.456451654434204 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15892/15892 [34:18<00:00, 7.72it/s] epoch #40: train_loss:3.449867010116577 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15892/15892 [34:19<00:00, 7.72it/s] epoch #41: train_loss:3.447544813156128 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15892/15892 [34:33<00:00, 7.66it/s] epoch #42: train_loss:3.4404118061065674 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15892/15892 [34:19<00:00, 7.72it/s] epoch #43: train_loss:3.4278857707977295 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15892/15892 [34:23<00:00, 7.70it/s] epoch #44: train_loss:3.421074628829956 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15892/15892 [34:47<00:00, 7.61it/s] epoch #45: train_loss:3.4152135848999023 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15892/15892 [34:38<00:00, 7.65it/s] epoch #46: train_loss:3.4102437496185303 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15892/15892 [34:42<00:00, 7.63it/s] epoch #47: train_loss:3.403414249420166 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15892/15892 [34:43<00:00, 7.63it/s] epoch #48: train_loss:3.391805648803711 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15892/15892 [34:44<00:00, 7.63it/s] epoch #49: train_loss:3.3876454830169678 開始:2022-07-17 21:52:28.838999 終了:2022-07-19 02:48:53.278059
import numpy as np from PIL import Image, ImageTk import cv2 import sys from time import time, sleep from tkinter import Tk, NW, TOP, Frame, Canvas import threading import torch import torch.nn as nn import torch.nn.functional as F from torchvision import transforms # Pix2PixHDのコードのものを使う from options.test_options import TestOptions from models.models import create_model # GPUを使うかどうか USE_DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu' device = torch.device(USE_DEVICE) # アプリケーションが実行中かどうか IS_RUN = True # PyTorchの内部を決定論的に設定する torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False # 乱数を初期化する np.random.seed(0) torch.manual_seed(0) if not torch.cuda.is_available(): # Pix2PixHDのオプションを上書き sys.argv.append('--gpu_ids') sys.argv.append('-1') # 画像のサイズ IMG_SIZE = 64 # Pix2PixHDのデフォルトオプションを上書きする opt = TestOptions().parse(save=False) opt.nThreads = 1 # スレッド数1 opt.batchSize = 1 # バッチサイズ1 opt.serial_batches = True # シャッフルしない opt.no_flip = True # 反転しない opt.name = 'line2illust' # 保存したモデルのディレクトリ名 opt.which_epoch = 'chapt07-model.pth' # 保存したモデルのファイル名 opt.label_nc = 2 # 入力は2値のラベル opt.no_instance = True # featmapは使わない # Pix2PixHDのモデルを作る model = create_model(opt) model = model.to(USE_DEVICE) # GPUを使うときはGPUメモリ上に乗せる model.eval() # Tkで表示するフレーム class MyFrame(Frame): def __init__(self, parent, **params): Frame.__init__(self, parent, params) # 前回のマウスの位置 self.mousepos = (None, None) # 入力データ self.in_img = np.zeros((64,64), dtype=np.uint8) # 出力データ self.out_img = np.zeros((3,64,64), dtype=np.float32) # キャンバスを配置する self.canvas1 = Canvas(self, width=320, height=320 ) self.canvas1.place(x=20,y=25) self.canvas2 = Canvas(self, width=320, height=320 ) self.canvas2.place(x=380,y=25) # 画像の表示場所を作成する self.image1 = Image.new('RGB',(320,320),(255,255,255)) self.imgtk1 = ImageTk.PhotoImage(self.image1) self.image2 = Image.new('RGB',(320,320),(255,255,255)) self.imgtk2 = ImageTk.PhotoImage(self.image2) self.canvas1.create_image(0,0,image=self.imgtk1,anchor=NW,tag='i') self.canvas2.create_image(0,0,image=self.imgtk2,anchor=NW,tag='o') # キャンバスにマウスイベントを設定 self.canvas1.bind( '', self.mouseOn ) self.canvas1.bind( ' ', self.mouseOff ) self.canvas1.bind( ' ', self.mouseOff ) def mouseOn(self, event): # マウスボタンON又はON状態で移動 x = int(64*event.x/320) # 入力データの位置 y = int(64*event.y/320) curpos = (x, y) # 現在の位置 if self.mousepos[0] is None: # 前回の位置がない self.mousepos = curpos # 前回のマウスの位置から現在の位置まで線を引く self.in_img = cv2.line(self.in_img,self.mousepos,curpos,1,1) self.mousepos = curpos # 前回の位置を更新 def mouseOff(self, event): # マウスボタンOFF又はキャンバスの外 self.mousepos = (None, None) def updateFrame(self): # 入力データを表示させる img = (self.in_img==0)*255 # 入力は背景が0なので反転 img = img.astype(np.uint8) # 8bitデータ # UI上に表示 self.image1 = Image.fromarray(img) self.imgtk1 = ImageTk.PhotoImage(self.image1.resize((320,320))) self.canvas1.itemconfigure(tagOrId='i', image=self.imgtk1) # 出力データを表示させる img = self.out_img # 出力はPix2PixHDが出力する形式なので変換 img = np.clip(img * 255, 0, 255) # 0〜255まで img = img.astype(np.uint8) # 8bitデータ img = img.reshape((3,64,64)) # カラー画像 img = img.transpose((1,2,0)) # 色チャンネルが最後 # UI上に表示 self.image2 = Image.fromarray(img) self.imgtk2 = ImageTk.PhotoImage(self.image2.resize((320,320))) self.canvas2.itemconfigure(tagOrId='o', image=self.imgtk2) # 0.1秒後に再び更新する self.after(100, self.updateFrame) # ニューラルネットワークを実行するスレッド def convert(frame): global IS_RUN s_time = time() with torch.no_grad(): # アプリケーションの実行中は無限ループ while IS_RUN: # キャプチャー速度に合わせて最大fpsを調整する a_time = time() # 線画を変換する image = (frame.in_img != 0).astype(np.int32) image = image.reshape((1,1,64,64)) # 1バッチ、1チャンネル # Pix2PixHDのモデルを実行 img_tensor = torch.tensor(image).to(USE_DEVICE) generated = model.inference(img_tensor, torch.tensor(0)) # Numpy型にする gen_img = generated.detach().cpu().numpy() # 出力された画像をセットする frame.out_img = gen_img # 0.1秒以下だったらその分待つ deltime = (time() - a_time) if deltime < 0.1: sleep(0.1-deltime) # 画面いっぱいにウィンドウを作成する win = Tk() win.geometry('720x380') # ウィンドウの大きさ frame = MyFrame(win, width=720, height=380, bg='gray') # フレーム frame.pack(side=TOP) # ウィンドウに配置 win.after_idle(frame.updateFrame) # 起動後にupdateFrameを呼び出す converter = threading.Thread(target=convert, args=(frame,)) converter.start() win.mainloop() # 処理を開始 # スレッドの終了を待つ IS_RUN = False