    """ Ensemble models for WaVo """
    
    import pickle
    
    import lightning.pytorch as pl
    
    import torch
    from torch import nn, optim
    
    import utils.utility as ut
    
    
    
    class WeightingModelAttention(nn.Module):
        """
        A PyTorch module that computes the weights for each model in the ensemble.
    
        Args:
            feature_count (int): The number of features in the input data.
            in_size (int): The maximum input size of the models in the ensemble.
            model_count (int): The number of models in the ensemble.
            out_size (int, optional): The output size of the models in the ensemble. Defaults to 48.
            hidden_size (int, optional): The number of nodes in the hidden layer of the weighting model. Defaults to 256.
            num_layers (int, optional): The number of layers in the weighting model. Defaults to 2.
            dropout (float, optional): The dropout rate of the weighting model. Defaults to 0.25.
            norm_func (str, optional): Name of the weight normalization, 'softmax' or 'minmax'. Defaults to 'softmax'.
        """
        def __init__(self,feature_count,in_size,model_count,out_size=48,hidden_size=256,num_layers=2,dropout=0.25,norm_func='softmax'):
            super().__init__()
            self.model_count = model_count
            self.out_size = out_size
            self.activation = nn.ReLU()
            self.norm_func_name = norm_func
    
            self.d_out_kq = 64
    
            # here the ensemble output is the query
            self.W_query = nn.Linear(model_count,self.d_out_kq,bias=False)
            self.W_key = nn.Linear(feature_count,self.d_out_kq,bias=False) #TODO turn bias on
            self.W_value = nn.Linear(feature_count,self.d_out_kq,bias=False) #TODO the value projection is unconstrained

            # here the original input is the query
            #self.W_query = nn.Linear(feature_count,self.d_out_kq,bias=False)
            #self.W_key = nn.Linear(model_count,self.d_out_kq,bias=False) #TODO turn bias on
            #self.W_value = nn.Linear(model_count,self.d_out_kq,bias=False) #TODO the value projection is unconstrained

            #self.W_query = nn.Parameter(torch.rand(feature_count,self.d_out_kq))
            #self.W_key = nn.Parameter(torch.rand(model_count,self.d_out_kq))
            #self.W_value = nn.Parameter(torch.rand(model_count,self.d_out_kq))
    
            if norm_func == 'softmax':
                self.norm_func = lambda x: nn.functional.softmax(x, dim=1)
            elif norm_func == 'minmax':
                # shift the minimum over the model dimension to 0, then normalize to sum to 1 over models
                self.norm_func = lambda x: (x - x.min(dim=1, keepdim=True).values) / (x - x.min(dim=1, keepdim=True).values).sum(dim=1, keepdim=True)
            else:
                raise ValueError(f"Unknown norm_func: {norm_func}")
    
            #self.l1 = nn.Linear(in_size*feature_count,hidden_size)
            #self.l2 = nn.Linear(hidden_size,model_count*out_size)
    
        def forward(self, x_1,x_2):
            x_1, x_2 = x_2, x_1  #TODO clean up: after the swap x_1 holds the stacked forecasts, x_2 the input window
            x_1 = torch.movedim(x_1,2,1)
            #x_2 = torch.movedim(x_2,2,1)
            queries_1 = self.W_query(x_1)
            keys_2 = self.W_key(x_2) # new
            values_2 = self.W_value(x_2)      # new
            attn_scores = queries_1 @ keys_2.mT # new
    
            #queries_1 = x_1 @ self.W_query
            #keys_2 = x_2 @ self.W_key          # new
            #values_2 = x_2 @ self.W_value      # new
            #attn_scores = queries_1 @ keys_2.T # new
    
            
    
            attn_weights = torch.softmax(
                attn_scores / self.d_out_kq**0.5, dim=-1)
            
            context_vec = attn_weights @ values_2
            return context_vec
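

    # Sketch (illustrative helper, not part of the original module; shapes assumed): WeightingModelAttention
    # is called by WaVoLightningAttentionEnsemble as weighting(x, y_hats), where x is the input window and
    # y_hats are the stacked member forecasts; it returns a context of shape (batch, out_size, d_out_kq).
    def _attention_weighting_shape_example():
        batch, in_size, feature_count, model_count, out_size = 4, 144, 12, 3, 48
        weighting = WeightingModelAttention(feature_count, in_size, model_count, out_size=out_size)
        x = torch.randn(batch, in_size, feature_count)      # input window
        y_hats = torch.randn(batch, model_count, out_size)  # stacked member forecasts
        context = weighting(x, y_hats)
        assert context.shape == (batch, out_size, weighting.d_out_kq)
        return context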
    
    
    class WeightingModel(nn.Module):
        """
        A PyTorch module that computes the weights for each model in the ensemble.
    
        Args:
            feature_count (int): The number of features in the input data.
            in_size (int): The maximum input size of the models in the ensemble.
            model_count (int): The number of models in the ensemble.
            out_size (int, optional): The output size of the models in the ensemble. Defaults to 48.
            hidden_size (int, optional): The number of nodes in the hidden layer of the weighting model. Defaults to 256.
            num_layers (int, optional): The number of layers in the weighting model. Defaults to 2.
            dropout (float, optional): The dropout rate of the weighting model. Defaults to 0.25.
            norm_func (str, optional): Name of the weight normalization, 'softmax' or 'minmax'. Defaults to 'softmax'.
        """
    
        def __init__(self,feature_count,in_size,model_count,out_size=48,hidden_size=256,num_layers=2,dropout=0.25,norm_func='softmax'):
    
            super().__init__()
            self.model_count = model_count
            self.out_size = out_size
    
            self.activation = nn.ReLU()
            self.norm_func_name = norm_func
    
            layers = [
                nn.Linear(in_size*feature_count,hidden_size),
                nn.Dropout(dropout),
                self.activation,
            ]
            for _ in range(num_layers - 2):
                layers += [
                    nn.Linear(in_features=hidden_size, out_features=hidden_size),
                    self.activation,
                    nn.Dropout(dropout),
                ]
            layers += [nn.Linear(hidden_size,model_count*out_size)]
            self.mlp = nn.Sequential(*layers)
    
            if norm_func == 'softmax':
                self.norm_func = lambda x: nn.functional.softmax(x, dim=1)
            elif norm_func == 'minmax':
                # shift the minimum over the model dimension to 0, then normalize to sum to 1 over models
                self.norm_func = lambda x: (x - x.min(dim=1, keepdim=True).values) / (x - x.min(dim=1, keepdim=True).values).sum(dim=1, keepdim=True)
            else:
                raise ValueError(f"Unknown norm_func: {norm_func}")
    
            #self.l1 = nn.Linear(in_size*feature_count,hidden_size)
            #self.l2 = nn.Linear(hidden_size,model_count*out_size)
    
    
        def forward(self, x):
            x = torch.reshape(x, (x.size(0), -1))
    
            x = self.mlp(x)
            #x = self.l1(x)
            #x = self.activation(x)
            #x = self.l2(x)
    
            x = x.view(x.size(0), self.model_count, self.out_size)
    
            x = self.norm_func(x)
    
            return x
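

    # Sketch (illustrative helper, not part of the original module; shapes assumed): WeightingModel maps a
    # flattened input window to one weight per ensemble member and lead time; with norm_func='softmax'
    # the weights sum to 1 over the model dimension.
    def _mlp_weighting_shape_example():
        batch, in_size, feature_count, model_count, out_size = 4, 144, 12, 3, 48
        weighting = WeightingModel(feature_count, in_size, model_count, out_size=out_size)
        x = torch.randn(batch, in_size, feature_count)
        w = weighting(x)  # (batch, model_count, out_size)
        assert w.shape == (batch, model_count, out_size)
        assert torch.allclose(w.sum(dim=1), torch.ones(batch, out_size), atol=1e-5)
        return w
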
    
    class WaVoLightningEnsemble(pl.LightningModule):
        """ Since the data for this is normalized using the scaler from the first model in model_list all models should be trained on the same or at least similar data i think?
        """
        # pylint: disable-next=unused-argument
        def __init__(self, model_list, model_path_list, hidden_size, num_layers, dropout, norm_func, learning_rate):
            super().__init__()
    
            self.model_list = model_list
            self.max_in_size = max([model.hparams['in_size'] for model in self.model_list])
    
            self.in_size = self.max_in_size
    
            self.out_size = model_list[0].hparams['out_size']
            self.feature_count = model_list[0].hparams['feature_count']
    
    
            assert len(set([model.target_idx for model in self.model_list])) == 1, "All models in the ensemble must have the same target_idx"
            self.target_idx = model_list[0].target_idx
            self.model_architecture = 'ensemble'
            self.scaler = model_list[0].scaler
    
    
            #self.weighting = WeightingModelAttention(
    
            self.weighting = WeightingModel(
                self.feature_count,
                self.max_in_size,
                len(model_list),
                self.out_size,
                hidden_size=hidden_size,
                num_layers=num_layers,
                dropout=dropout,
                norm_func=norm_func)
    
    
            #TODO different insizes
            #TODO ModuleList?!
    
            self.save_hyperparameters(ignore=['model_list','scaler'])
            self.save_hyperparameters({"scaler": pickle.dumps(self.scaler)})
    
            for model in self.model_list:
                model.freeze()
    
        def _common_step(self, batch, batch_idx,dataloader_idx=0):
            """
            Computes the weighted average of the predictions of the models in the ensemble.
    
            Args:
                batch (tuple): A tuple containing the input data and the target data.
                batch_idx (int): The index of the batch.
                dataloader_idx (int, optional): The index of the dataloader. Defaults to 0.
    
            Returns:
                torch.Tensor: The weighted average of the predictions of the models in the ensemble.
            """
    
            x = batch[0]
            y_hat_list = []
            for model in self.model_list:
    
                #if ut.need_classic_input(model.model_architecture):
                #    temp_batch = batch[0],batch[1]
                #else:
                #    temp_batch = batch
    
                y_hat_list.append(model.predict_step(batch,batch_idx,dataloader_idx))
    
            #TODO what about scaling?! before/after?
            y_hats = torch.stack(y_hat_list, dim=1)
            w = self.weighting(x)
            y_hat = torch.sum(y_hats * w, dim=1)
    
    Michel Spils's avatar
    Michel Spils committed
            return y_hat
    
    
        # pylint: disable-next=arguments-differ
        def training_step(self, batch, batch_idx):
            """
            Computes the training loss for the ensemble.
    
            Args:
                batch (tuple): A tuple containing the input data and the target data.
                batch_idx (int): The index of the batch.
    
            Returns:
                torch.Tensor: The training loss for the ensemble.
            """
            # Somewhat over-engineered: handles both the 4-tuple (transformer-style) and 2-tuple (x, y) batch formats.
            if len(batch) == 4:
                _, y, _, _ = batch
                y = y[:, -self.out_size:,self.target_idx]
            elif len(batch) == 2:
                _, y = batch
    
            y_inv =  ut.inv_standard(y, self.scaler.mean_[self.target_idx], self.scaler.scale_[self.target_idx])
            y_hat = self._common_step(batch, batch_idx)
            loss = nn.functional.mse_loss(y_hat,y_inv)
            self.log("hp/train_loss", loss)
            return loss
    
        # pylint: disable-next=arguments-differ
        def validation_step(self, batch, batch_idx):
            """
            Computes the validation loss for the ensemble.
    
            Args:
                batch (tuple): A tuple containing the input data and the target data.
                batch_idx (int): The index of the batch.
            """
    
            # Somewhat over-engineered: handles both the 4-tuple (transformer-style) and 2-tuple (x, y) batch formats.
            if len(batch) == 4:
                _, y, _, _ = batch
                y = y[:, -self.out_size:,self.target_idx]
            elif len(batch) == 2:
                _, y = batch
    
            y_inv =  ut.inv_standard(y, self.scaler.mean_[self.target_idx], self.scaler.scale_[self.target_idx])
            y_hat = self._common_step(batch, batch_idx)
            val_loss = nn.functional.mse_loss(y_hat, y_inv)
            self.log("hp/val_loss", val_loss, sync_dist=True)
    
    
    
        # pylint: disable-next=arguments-differ
        def test_step(self, batch, batch_idx):
            """
            Computes the test loss for the ensemble.
    
            Args:
                batch (tuple): A tuple containing the input data and the target data.
                batch_idx (int): The index of the batch.
            """
            # Somewhat over-engineered: handles both the 4-tuple (transformer-style) and 2-tuple (x, y) batch formats.
            if len(batch) == 4:
                _, y, _, _ = batch
                y = y[:, -self.out_size:,self.target_idx]
            elif len(batch) == 2:
                _, y = batch
    
            y_inv =  ut.inv_standard(y, self.scaler.mean_[self.target_idx], self.scaler.scale_[self.target_idx])
            y_hat = self._common_step(batch, batch_idx)
            test_loss = nn.functional.mse_loss(y_hat, y_inv)
            self.log("hp/test_loss", test_loss, sync_dist=True)
    
    
        def predict_step(self, batch, batch_idx, dataloader_idx=0):
            y_hat = self._common_step(batch, batch_idx,dataloader_idx)
    
            return y_hat
    
        # pylint: disable-next=arguments-differ
        def forward_old(self, x): #TODO remove?
            y_hat_list = []
            for model in self.model_list:
                y_hat_list.append(model.predict_step((x,None),batch_idx=1))
    
            y_hats = torch.stack(y_hat_list, dim=1)
            w = self.weighting(x)
            y_hat = torch.sum(y_hats * w, dim=1)
            return y_hat
    
        def configure_optimizers(self):
            optimizer = optim.Adam(
                self.parameters(), lr=self.hparams.learning_rate)
            return optimizer
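

    # Sketch (illustrative helper with synthetic tensors, not the original data pipeline): how
    # WaVoLightningEnsemble._common_step combines the members -- forecasts are stacked on dim 1,
    # weighted by the WeightingModel output, and summed over the model dimension.
    def _ensemble_combination_example():
        batch, in_size, feature_count, model_count, out_size = 4, 144, 12, 3, 48
        x = torch.randn(batch, in_size, feature_count)      # input window (batch[0] in _common_step)
        y_hats = torch.randn(batch, model_count, out_size)  # stand-in for the stacked member forecasts
        weighting = WeightingModel(feature_count, in_size, model_count, out_size=out_size)
        y_hat = torch.sum(y_hats * weighting(x), dim=1)     # (batch, out_size)
        assert y_hat.shape == (batch, out_size)
        return y_hat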
    
    
    class WaVoLightningAttentionEnsemble(pl.LightningModule):
        """ Since the data for this is normalized using the scaler from the first model in model_list all models should be trained on the same or at least similar data i think?
        """
        # pylint: disable-next=unused-argument
        def __init__(self, model_list,model_path_list,hidden_size,num_layers,dropout,norm_func,learning_rate):
            super().__init__()
    
            self.model_list = model_list
            self.max_in_size = max([model.hparams['in_size'] for model in self.model_list])
            self.out_size = model_list[0].hparams['out_size']
            self.feature_count = model_list[0].hparams['feature_count']
    
            assert len(set([model.target_idx for model in self.model_list])) == 1, "All models in the ensemble must have the same target_idx"
            self.target_idx = model_list[0].target_idx
            self.model_architecture = 'ensemble'
            self.scaler = model_list[0].scaler
    
            self.weighting = WeightingModelAttention(
                self.feature_count,
                self.max_in_size,
                len(model_list),
                self.out_size,
                hidden_size=hidden_size,
                num_layers=num_layers,
                dropout=dropout,
                norm_func=norm_func)
    
    
    
            #TODO different insizes
            #TODO ModuleList?!
            self.save_hyperparameters(ignore=['model_list','scaler'])
            self.save_hyperparameters({"scaler": pickle.dumps(self.scaler)})
            for model in self.model_list:
                model.freeze()
    
        def _common_step(self, batch, batch_idx,dataloader_idx=0):
            """
            Computes the weighted average of the predictions of the models in the ensemble.
    
            Args:
                batch (tuple): A tuple containing the input data and the target data.
                batch_idx (int): The index of the batch.
                dataloader_idx (int, optional): The index of the dataloader. Defaults to 0.
    
            Returns:
                torch.Tensor: The weighted average of the predictions of the models in the ensemble.
            """
            x  = batch[0]
            y_hat_list = []
            for model in self.model_list:
                #if ut.need_classic_input(model.model_architecture):
                #    temp_batch = batch[0],batch[1]
                #else:
                #    temp_batch = batch
                #y_hat_list.append(model.predict_step(batch,batch_idx,dataloader_idx))
                y_hat_list.append(model(x))
    
            #TODO what about scaling?! before/after?
            #TODO check shapes: WeightingModelAttention returns (batch, out_size, d_out_kq), but y_hats is (batch, model_count, out_size)
            y_hats = torch.stack(y_hat_list, dim=1)
            w = self.weighting(x, y_hats)
            y_hat = torch.sum(y_hats * w, dim=1)
    
            return y_hat
    
    
    
        # pylint: disable-next=arguments-differ
        def training_step(self, batch, batch_idx):
            """
            Computes the training loss for the ensemble.
    
            Args:
                batch (tuple): A tuple containing the input data and the target data.
                batch_idx (int): The index of the batch.
    
            Returns:
                torch.Tensor: The training loss for the ensemble.
            """
    
            # Somewhat over-engineered: handles both the 4-tuple (transformer-style) and 2-tuple (x, y) batch formats.
            if len(batch) == 4:
                _, y, _, _ = batch
                y = y[:, -self.out_size:,self.target_idx]
            elif len(batch) == 2:
                _, y = batch
    
            y_inv =  ut.inv_standard(y, self.scaler.mean_[self.target_idx], self.scaler.scale_[self.target_idx])
    
            y_hat = self._common_step(batch, batch_idx)
    
            loss = nn.functional.mse_loss(y_hat,y_inv)
    
            self.log("hp/train_loss", loss)
            return loss
    
        # pylint: disable-next=arguments-differ
        def validation_step(self, batch, batch_idx):
            """
            Computes the validation loss for the ensemble.
    
            Args:
                batch (tuple): A tuple containing the input data and the target data.
                batch_idx (int): The index of the batch.
            """
    
    
            # Somewhat over-engineered: handles both the 4-tuple (transformer-style) and 2-tuple (x, y) batch formats.
            if len(batch) == 4:
                _, y, _, _ = batch
                y = y[:, -self.out_size:,self.target_idx]
            elif len(batch) == 2:
                _, y = batch
    
            y_inv =  ut.inv_standard(y, self.scaler.mean_[self.target_idx], self.scaler.scale_[self.target_idx])
            y_hat = self._common_step(batch, batch_idx)
            val_loss = nn.functional.mse_loss(y_hat, y_inv)
    
            self.log("hp/val_loss", val_loss, sync_dist=True)
    
    
        # pylint: disable-next=arguments-differ
        def test_step(self, batch, batch_idx):
            """
            Computes the test loss for the ensemble.
    
            Args:
                batch (tuple): A tuple containing the input data and the target data.
                batch_idx (int): The index of the batch.
            """
    
            # Somewhat over-engineered: handles both the 4-tuple (transformer-style) and 2-tuple (x, y) batch formats.
            if len(batch) == 4:
                _, y, _, _ = batch
                y = y[:, -self.out_size:,self.target_idx]
            elif len(batch) == 2:
                _, y = batch
    
            y_inv =  ut.inv_standard(y, self.scaler.mean_[self.target_idx], self.scaler.scale_[self.target_idx])
    
            y_hat = self._common_step(batch, batch_idx)
    
            test_loss = nn.functional.mse_loss(y_hat, y_inv)
            self.log("hp/test_loss", test_loss, sync_dist=True)
    
    
    
        def predict_step(self, batch, batch_idx, dataloader_idx=0):
            y_hat = self._common_step(batch, batch_idx,dataloader_idx)
    
            return y_hat
    
        # pylint: disable-next=arguments-differ
        def forward_old(self, x): #TODO remove?
            y_hat_list = []
            for model in self.model_list:
                y_hat_list.append(model.predict_step((x,None),batch_idx=1))
    
            y_hats = torch.stack(y_hat_list, dim=1)
            w = self.weighting(x, y_hats)
            y_hat = torch.sum(y_hats * w, dim=1)
            return y_hat
    
        def configure_optimizers(self):
            optimizer = optim.Adam(
                self.parameters(), lr=self.hparams.learning_rate)
            return optimizer
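

    # Minimal smoke test (sketch, shapes assumed): exercises the illustrative helpers above with random tensors.
    if __name__ == "__main__":
        _attention_weighting_shape_example()
        _mlp_weighting_shape_example()
        _ensemble_combination_example()
        print("ensemble shape examples ran without errors")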