
 

1. How to Implement a PyTorch Model

 

Until now I had only used TensorFlow and Keras, but working on image detection gave me a chance to use PyTorch. It turned out to be far from easy: most PyTorch implementations build the model as a class, which is quite different from TensorFlow. The structure is very standardized, so it can feel like fitting pieces into a template, but I am sticking with it in the hope that it gets easier with familiarity.

 

 

 

PyTorch provides the following tools to make working with datasets easier:

torch.utils.data.Dataset

torch.utils.data.DataLoader

 

 

2. Basic Structure

2.1 Dataset 

 

class datasetName(torch.utils.data.Dataset):
    def __init__(self):          # 1
        ...

    def __len__(self):           # 2
        ...

    def __getitem__(self, idx):  # 3
        ...

This is the basic structure of a Dataset.

 

#1 -> where the basic data preprocessing is done

#2 -> returns the length of the dataset

#3 -> fetches a single sample from the dataset by index
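
For example, a minimal Dataset over in-memory tensors (the names and data here are hypothetical, just to show the three methods in action) could look like this:

import torch

class TensorPairDataset(torch.utils.data.Dataset):
    def __init__(self, features, labels):     # 1: store / preprocess the data
        self.features = features
        self.labels = labels

    def __len__(self):                         # 2: number of samples
        return len(self.features)

    def __getitem__(self, idx):                # 3: return one sample by index
        return self.features[idx], self.labels[idx]

dataset = TensorPairDataset(torch.randn(100, 3), torch.randint(0, 4, (100,)))
loader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=True)
x, y = next(iter(loader))    # x: (8, 3) float batch, y: (8,) labels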

 

 

2.2 run_training()

This is where training is actually run. It builds two loaders, train_loader and val_loader.

from torch.utils.data import RandomSampler, SequentialSampler

def run_training():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net.to(device)

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=TrainGlobalConfig.batch_size,
        sampler=RandomSampler(train_dataset),
        pin_memory=False,
        drop_last=True,
        num_workers=TrainGlobalConfig.num_workers
    )

    val_loader = torch.utils.data.DataLoader(
        validation_dataset,
        batch_size=TrainGlobalConfig.batch_size,
        num_workers=TrainGlobalConfig.num_workers,
        shuffle=False,
        sampler=SequentialSampler(validation_dataset),
        pin_memory=False
    )

    fitter = Fitter(model=net, device=device, config=TrainGlobalConfig)
    fitter.fit(train_loader, val_loader)

 

2.3 Fitter

- The part that trains the model and computes the loss and score.

 

class Fitter

- def __init__(self, model, device, config):

- def fit(self, train_loader, validation_loader):

- def validation(self, val_loader):

- def train_one_epoch(self, train_loader):

- def save(self, path):

- def load(self, path):

- def log(self, message):

 

import os
import time
from datetime import datetime
from glob import glob

import torch

class Fitter:

    def __init__(self, model, device, config):
        self.config = config
        self.epoch = 0

        self.base_dir = f'./{config.folder}'
        if not os.path.exists(self.base_dir):
            os.makedirs(self.base_dir)

        self.log_path = f'{self.base_dir}/log.txt'
        self.best_summary_loss = 10**5

        self.model = model
        self.device = device

        # parameter groups that exempt bias/LayerNorm weights from weight decay
        param_optimizer = list(self.model.named_parameters())
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.001},
            {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]

        # note: optimizer_grouped_parameters is built but not used below;
        # the optimizer is created over model.parameters() as-is
        self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=config.lr)
        self.scheduler = config.SchedulerClass(self.optimizer, **config.scheduler_params)
        self.log(f'Fitter prepared. Device is {self.device}')

    def fit(self, train_loader, validation_loader):
        for e in range(self.config.n_epochs):
            if self.config.verbose:
                lr = self.optimizer.param_groups[0]['lr']
                timestamp = datetime.utcnow().isoformat()
                self.log(f'\n{timestamp}\nLR: {lr}')

            # train for one epoch, then save the rolling "last" checkpoint
            t = time.time()
            summary_loss = self.train_one_epoch(train_loader)

            self.log(f'[RESULT]: Train. Epoch: {self.epoch}, summary_loss: {summary_loss.avg:.5f}, time: {(time.time() - t):.5f}')
            self.save(f'{self.base_dir}/last-checkpoint.bin')

            t = time.time()
            summary_loss = self.validation(validation_loader)

            self.log(f'[RESULT]: Val. Epoch: {self.epoch}, summary_loss: {summary_loss.avg:.5f}, time: {(time.time() - t):.5f}')
            if summary_loss.avg < self.best_summary_loss:
                self.best_summary_loss = summary_loss.avg
                self.model.eval()
                self.save(f'{self.base_dir}/best-checkpoint-{str(self.epoch).zfill(3)}epoch.bin')
                # keep only the three most recent best checkpoints
                for path in sorted(glob(f'{self.base_dir}/best-checkpoint-*epoch.bin'))[:-3]:
                    os.remove(path)

            if self.config.validation_scheduler:
                self.scheduler.step(metrics=summary_loss.avg)

            self.epoch += 1

    def validation(self, val_loader):
        self.model.eval()
        summary_loss = AverageMeter()
        t = time.time()
        for step, (images, targets, image_ids) in enumerate(val_loader):
            if self.config.verbose:
                if step % self.config.verbose_step == 0:
                    print(
                        f'Val Step {step}/{len(val_loader)}, ' + \
                        f'summary_loss: {summary_loss.avg:.5f}, ' + \
                        f'time: {(time.time() - t):.5f}', end='\r'
                    )
            with torch.no_grad():
                # images/targets arrive as sequences of per-sample tensors
                # (a custom collate_fn in the DataLoader is assumed)
                images = torch.stack(images)
                batch_size = images.shape[0]
                images = images.to(self.device).float()
                boxes = [target['boxes'].to(self.device).float() for target in targets]
                labels = [target['labels'].to(self.device).float() for target in targets]

                loss, _, _ = self.model(images, boxes, labels)
                summary_loss.update(loss.detach().item(), batch_size)

        return summary_loss

    def train_one_epoch(self, train_loader):
        self.model.train()
        summary_loss = AverageMeter()
        t = time.time()
        for step, (images, targets, image_ids) in enumerate(train_loader):
            if self.config.verbose:
                if step % self.config.verbose_step == 0:
                    print(
                        f'Train Step {step}/{len(train_loader)}, ' + \
                        f'summary_loss: {summary_loss.avg:.5f}, ' + \
                        f'time: {(time.time() - t):.5f}', end='\r'
                    )

            images = torch.stack(images)
            images = images.to(self.device).float()
            batch_size = images.shape[0]
            boxes = [target['boxes'].to(self.device).float() for target in targets]
            labels = [target['labels'].to(self.device).float() for target in targets]

            self.optimizer.zero_grad()

            loss, _, _ = self.model(images, boxes, labels)

            loss.backward()

            summary_loss.update(loss.detach().item(), batch_size)

            self.optimizer.step()

            if self.config.step_scheduler:
                self.scheduler.step()

        return summary_loss

    def save(self, path):
        self.model.eval()
        torch.save({
            'model_state_dict': self.model.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'best_summary_loss': self.best_summary_loss,
            'epoch': self.epoch,
        }, path)

    def load(self, path):
        checkpoint = torch.load(path)
        self.model.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        self.best_summary_loss = checkpoint['best_summary_loss']
        self.epoch = checkpoint['epoch'] + 1

    def log(self, message):
        if self.config.verbose:
            print(message)
        with open(self.log_path, 'a+') as logger:
            logger.write(f'{message}\n')
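
Note that the Fitter above relies on an AverageMeter helper that is not shown in the post. A minimal sketch of the usual implementation (a running average of the loss, weighted by batch size) would be:

class AverageMeter:
    """Tracks the current value, running sum, count, and average."""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        # val: latest loss value, n: number of samples it was averaged over
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count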

 

 

 

 

2.4 get_net()

- The part that builds the network (net). You can either define the architecture yourself or bring in an existing implementation.

 

ex1)

 

from effdet import get_efficientdet_config, EfficientDet, DetBenchTrain
from effdet.efficientdet import HeadNet

def get_net():
    config = get_efficientdet_config('tf_efficientdet_d5')
    net = EfficientDet(config, pretrained_backbone=False)
    checkpoint = torch.load('../input/efficientdet/efficientdet_d5-ef44aea8.pth')
    net.load_state_dict(checkpoint)
    config.num_classes = 1
    config.image_size = 512
    net.class_net = HeadNet(config, num_outputs=config.num_classes, norm_kwargs=dict(eps=.001, momentum=.01))
    return DetBenchTrain(net, config)

net = get_net()
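
For reference, DetBenchTrain here wraps the raw EfficientDet network in a training "bench" whose forward pass computes the detection loss itself; as far as I understand the effdet API, that is what lets the Fitter above call self.model(images, boxes, labels) and unpack a loss directly.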

 

 

ex2) 

 

import timm
import torch.nn as nn

def get_net():
    net = timm.create_model('seresnet18', pretrained=True)
    # replace the classification head with a 4-class linear layer
    net.last_linear = nn.Linear(in_features=net.last_linear.in_features, out_features=4, bias=True)

    return net

net = get_net().cuda()

 

 

2.5 TrainGlobalConfig

 

class TrainGlobalConfig:
    num_workers = 4
    batch_size = 32
    n_epochs = 10
    lr = 0.001

    folder = 'folderName'

    verbose = True
    verbose_step = 1

    step_scheduler = False  # do scheduler.step after optimizer.step
    validation_scheduler = True  # do scheduler.step after validation stage loss

    SchedulerClass = torch.optim.lr_scheduler.ReduceLROnPlateau
    scheduler_params = dict(
        mode='min',
        factor=0.5,
        patience=2,
        verbose=False,
        threshold=0.0001,
        threshold_mode='abs',
        cooldown=0,
        min_lr=1e-8,
        eps=1e-08
    )
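
Putting the pieces together, the overall flow of the script is roughly the following (a sketch assuming train_dataset and validation_dataset have already been built with the Dataset class from 2.1):

# sketch of the overall flow; train_dataset / validation_dataset are assumed
# to be instances of the Dataset class from section 2.1
net = get_net()    # 2.4: build the network
run_training()     # 2.2: build the loaders, create a Fitter (2.3), and fit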

 

 

 

 

<Sources>

1. https://wikidocs.net/57165


 
