diff --git a/src/main/java/com/kamco/cd/training/model/dto/HyperParamDto.java b/src/main/java/com/kamco/cd/training/model/dto/HyperParamDto.java index 76cd4ad..84a0869 100644 --- a/src/main/java/com/kamco/cd/training/model/dto/HyperParamDto.java +++ b/src/main/java/com/kamco/cd/training/model/dto/HyperParamDto.java @@ -138,7 +138,6 @@ public class HyperParamDto { this.hueDelta = entity.getHueDelta(); // Legacy - this.dropoutRatio = entity.getDropoutRatio(); this.cnnFilterCnt = entity.getCnnFilterCnt(); // Common diff --git a/src/main/java/com/kamco/cd/training/postgres/core/HyperParamCoreService.java b/src/main/java/com/kamco/cd/training/postgres/core/HyperParamCoreService.java index eecd67b..e275073 100644 --- a/src/main/java/com/kamco/cd/training/postgres/core/HyperParamCoreService.java +++ b/src/main/java/com/kamco/cd/training/postgres/core/HyperParamCoreService.java @@ -77,7 +77,6 @@ public class HyperParamCoreService { .saturationRange(entity.getSaturationRange()) .hueDelta(entity.getHueDelta()) // Legacy - .dropoutRatio(entity.getDropoutRatio()) .cnnFilterCnt(entity.getCnnFilterCnt()) // Common .memo(entity.getMemo()) @@ -247,10 +246,6 @@ public class HyperParamCoreService { createReq.getHueDelta() != null ? createReq.getHueDelta() : baseEntity.getHueDelta()); // Legacy - entity.setDropoutRatio( - createReq.getDropoutRatio() != null - ? createReq.getDropoutRatio() - : baseEntity.getDropoutRatio()); entity.setCnnFilterCnt( createReq.getCnnFilterCnt() != null ? createReq.getCnnFilterCnt() diff --git a/src/main/java/com/kamco/cd/training/postgres/entity/ModelHyperParamEntity.java b/src/main/java/com/kamco/cd/training/postgres/entity/ModelHyperParamEntity.java index d4cf6c2..cb33384 100644 --- a/src/main/java/com/kamco/cd/training/postgres/entity/ModelHyperParamEntity.java +++ b/src/main/java/com/kamco/cd/training/postgres/entity/ModelHyperParamEntity.java @@ -2,231 +2,219 @@ package com.kamco.cd.training.postgres.entity; import jakarta.persistence.Column; import jakarta.persistence.Entity; +import jakarta.persistence.GeneratedValue; +import jakarta.persistence.GenerationType; import jakarta.persistence.Id; import jakarta.persistence.Table; import jakarta.validation.constraints.NotNull; import jakarta.validation.constraints.Size; import java.time.ZonedDateTime; -import lombok.AllArgsConstructor; +import java.util.UUID; import lombok.Getter; -import lombok.NoArgsConstructor; import lombok.Setter; import org.hibernate.annotations.ColumnDefault; @Getter @Setter @Entity -@NoArgsConstructor -@AllArgsConstructor -@Table(name = "tb_model_hyper_param") +@Table(name = "tb_model_hyper_params") public class ModelHyperParamEntity { @Id + @GeneratedValue(strategy = GenerationType.IDENTITY) + @Column(name = "hyper_param_id", nullable = false) + private Long id; + @NotNull + @ColumnDefault("gen_random_uuid()") + @Column(name = "uuid", nullable = false) + private UUID uuid = UUID.randomUUID(); + @Size(max = 50) + @NotNull @Column(name = "hyper_ver", nullable = false, length = 50) private String hyperVer; - // ==================== Important Parameters ==================== + @Size(max = 20) + @NotNull + @Column(name = "backbone", nullable = false, length = 20) + private String backbone; + + @Size(max = 15) + @NotNull + @Column(name = "input_size", nullable = false, length = 15) + private String inputSize; + + @Size(max = 15) + @NotNull + @Column(name = "crop_size", nullable = false, length = 15) + private String cropSize; + + @NotNull + @Column(name = "epoch_cnt", nullable = false) + private Integer epochCnt; + + @NotNull + @Column(name = "batch_size", nullable = false) + private Integer batchSize; + + @NotNull + @Column(name = "drop_path_rate", nullable = false) + private Double dropPathRate; + + @NotNull + @Column(name = "frozen_stages", nullable = false) + private Integer frozenStages; @Size(max = 20) - @ColumnDefault("'large'") - @Column(name = "backbone", length = 20) - private String backbone = "large"; - - @Size(max = 20) - @ColumnDefault("'256,256'") - @Column(name = "input_size", length = 20) - private String inputSize = "256,256"; - - @Size(max = 20) - @ColumnDefault("'256,256'") - @Column(name = "crop_size", length = 20) - private String cropSize = "256,256"; - - @ColumnDefault("200") - @Column(name = "epoch_cnt") - private Integer epochCnt = 200; - - @ColumnDefault("16") - @Column(name = "batch_size") - private Integer batchSize = 16; - - // ==================== Model Architecture ==================== - - @ColumnDefault("0.3") - @Column(name = "drop_path_rate") - private Double dropPathRate = 0.3; - - @ColumnDefault("-1") - @Column(name = "frozen_stages") - private Integer frozenStages = -1; - - @Size(max = 20) - @ColumnDefault("'abs_diff'") - @Column(name = "neck_policy", length = 20) - private String neckPolicy = "abs_diff"; - - @Size(max = 255) - @ColumnDefault("'512,256,128,64'") - @Column(name = "decoder_channels", length = 255) - private String decoderChannels = "512,256,128,64"; - - @Size(max = 500) - @Column(name = "class_weight", length = 500) - private String classWeight; - - @Column(name = "num_layers") - private Integer numLayers; - - // ==================== Loss & Optimization ==================== - - @ColumnDefault("0.00006") - @Column(name = "learning_rate") - private Double learningRate = 0.00006; - - @ColumnDefault("0.05") - @Column(name = "weight_decay") - private Double weightDecay = 0.05; - - @ColumnDefault("0.9") - @Column(name = "layer_decay_rate") - private Double layerDecayRate = 0.9; - - @ColumnDefault("true") - @Column(name = "ddp_find_unused_params") - private Boolean ddpFindUnusedParams = true; - - @ColumnDefault("255") - @Column(name = "ignore_index") - private Integer ignoreIndex = 255; - - // ==================== Data ==================== - - @ColumnDefault("16") - @Column(name = "train_num_workers") - private Integer trainNumWorkers = 16; - - @ColumnDefault("8") - @Column(name = "val_num_workers") - private Integer valNumWorkers = 8; - - @ColumnDefault("8") - @Column(name = "test_num_workers") - private Integer testNumWorkers = 8; - - @ColumnDefault("true") - @Column(name = "train_shuffle") - private Boolean trainShuffle = true; - - @ColumnDefault("true") - @Column(name = "train_persistent") - private Boolean trainPersistent = true; - - @ColumnDefault("true") - @Column(name = "val_persistent") - private Boolean valPersistent = true; - - // ==================== Evaluation ==================== - - @Size(max = 255) - @ColumnDefault("'mFscore,mIoU'") - @Column(name = "metrics", length = 255) - private String metrics = "mFscore,mIoU"; + @NotNull + @Column(name = "neck_policy", nullable = false, length = 20) + private String neckPolicy; @Size(max = 50) - @ColumnDefault("'changed_fscore'") - @Column(name = "save_best", length = 50) - private String saveBest = "changed_fscore"; + @NotNull + @Column(name = "decoder_channels", nullable = false, length = 50) + private String decoderChannels; - @Size(max = 20) - @ColumnDefault("'greater'") - @Column(name = "save_best_rule", length = 20) - private String saveBestRule = "greater"; + @Size(max = 50) + @NotNull + @Column(name = "class_weight", nullable = false, length = 50) + private String classWeight; - @ColumnDefault("10") - @Column(name = "val_interval") - private Integer valInterval = 10; + @NotNull + @Column(name = "num_layers", nullable = false) + private Integer numLayers; - @ColumnDefault("400") - @Column(name = "log_interval") - private Integer logInterval = 400; + @NotNull + @Column(name = "learning_rate", nullable = false) + private Double learningRate; - @ColumnDefault("1") - @Column(name = "vis_interval") - private Integer visInterval = 1; + @NotNull + @Column(name = "weight_decay", nullable = false) + private Double weightDecay; - // ==================== Hardware ==================== + @NotNull + @Column(name = "layer_decay_rate", nullable = false) + private Double layerDecayRate; - @ColumnDefault("4") - @Column(name = "gpu_cnt") - private Integer gpuCnt = 4; + @NotNull + @Column(name = "ddp_find_unused_params", nullable = false) + private Boolean ddpFindUnusedParams = false; + + @NotNull + @Column(name = "ignore_index", nullable = false) + private Integer ignoreIndex; + + @NotNull + @Column(name = "train_num_workers", nullable = false) + private Integer trainNumWorkers; + + @NotNull + @Column(name = "val_num_workers", nullable = false) + private Integer valNumWorkers; + + @NotNull + @Column(name = "test_num_workers", nullable = false) + private Integer testNumWorkers; + + @NotNull + @Column(name = "train_shuffle", nullable = false) + private Boolean trainShuffle = false; + + @NotNull + @Column(name = "train_persistent", nullable = false) + private Boolean trainPersistent = false; + + @NotNull + @Column(name = "val_persistent", nullable = false) + private Boolean valPersistent = false; @Size(max = 100) - @ColumnDefault("'0,1,2,3'") - @Column(name = "gpu_ids", length = 100) - private String gpuIds = "0,1,2,3"; + @NotNull + @Column(name = "metrics", nullable = false, length = 100) + private String metrics; + + @Size(max = 30) + @NotNull + @Column(name = "save_best", nullable = false, length = 30) + private String saveBest; + + @Size(max = 10) + @NotNull + @Column(name = "save_best_rule", nullable = false, length = 10) + private String saveBestRule; + + @NotNull + @Column(name = "val_interval", nullable = false) + private Integer valInterval; + + @NotNull + @Column(name = "log_interval", nullable = false) + private Integer logInterval; + + @NotNull + @Column(name = "vis_interval", nullable = false) + private Integer visInterval; + + @NotNull + @Column(name = "rot_prob", nullable = false) + private Double rotProb; + + @NotNull + @Column(name = "flip_prob", nullable = false) + private Double flipProb; + + @Size(max = 20) + @NotNull + @Column(name = "rot_degree", nullable = false, length = 20) + private String rotDegree; + + @NotNull + @Column(name = "exchange_prob", nullable = false) + private Double exchangeProb; + + @NotNull + @Column(name = "brightness_delta", nullable = false) + private Integer brightnessDelta; + + @Size(max = 20) + @NotNull + @Column(name = "contrast_range", nullable = false, length = 20) + private String contrastRange; + + @Size(max = 20) + @NotNull + @Column(name = "saturation_range", nullable = false, length = 20) + private String saturationRange; + + @NotNull + @Column(name = "hue_delta", nullable = false) + private Integer hueDelta; + + @NotNull + @Column(name = "gpu_cnt", nullable = false) + private Integer gpuCnt; + + @Size(max = 50) + @Column(name = "gpu_ids", length = 50) + private String gpuIds; - @ColumnDefault("1122") @Column(name = "master_port") - private Integer masterPort = 1122; - - // ==================== Augmentation ==================== - - @ColumnDefault("0.5") - @Column(name = "rot_prob") - private Double rotProb = 0.5; - - @ColumnDefault("0.5") - @Column(name = "flip_prob") - private Double flipProb = 0.5; - - @Size(max = 20) - @ColumnDefault("'-20,20'") - @Column(name = "rot_degree", length = 20) - private String rotDegree = "-20,20"; - - @ColumnDefault("0.5") - @Column(name = "exchange_prob") - private Double exchangeProb = 0.5; - - @ColumnDefault("10") - @Column(name = "brightness_delta") - private Integer brightnessDelta = 10; - - @Size(max = 20) - @ColumnDefault("'0.8,1.2'") - @Column(name = "contrast_range", length = 20) - private String contrastRange = "0.8,1.2"; - - @Size(max = 20) - @ColumnDefault("'0.8,1.2'") - @Column(name = "saturation_range", length = 20) - private String saturationRange = "0.8,1.2"; - - @ColumnDefault("10") - @Column(name = "hue_delta") - private Integer hueDelta = 10; - - // ==================== Legacy (deprecated) ==================== - - @Column(name = "cnn_filter_cnt") - private Integer cnnFilterCnt; - - @Column(name = "dropout_ratio") - private Double dropoutRatio; - - // ==================== Common ==================== + private Integer masterPort; @Column(name = "memo", length = Integer.MAX_VALUE) private String memo; - @Size(max = 255) + @NotNull @ColumnDefault("'N'") - @Column(name = "del_yn", length = 255) - private String delYn = "N"; + @Column(name = "del_yn", nullable = false, length = Integer.MAX_VALUE) + private String delYn; - @ColumnDefault("now()") - @Column(name = "created_dttm") + @NotNull + @ColumnDefault("CURRENT_TIMESTAMP") + @Column(name = "created_dttm", nullable = false) private ZonedDateTime createdDttm; + + @Column(name = "cnn_filter_cnt") + private Integer cnnFilterCnt; }