update
@@ -1,7 +1,7 @@
 # Example: HAI 21.03 Feature Split
 
 This folder contains a small, reproducible example that inspects the HAI 21.03
-CSV (train1) and produces a continuous/discrete split using a simple heuristic.
+CSV (all train*.csv.gz files) and produces a continuous/discrete split using a simple heuristic.
 
 ## Files
 - analyze_hai21_03.py: reads a sample of the data and writes results.
@@ -60,6 +60,12 @@ python example/run_pipeline.py --device auto
 - Set `device` in `example/config.json` to `auto` or `cuda` when moving to a GPU machine.
 - Attack label columns (`attack*`) are excluded from training and generation.
 - `time` column is always excluded from training and generation (optional for export only).
+- EMA weights are saved as `model_ema.pt` and used by the pipeline for sampling.
+- Gradients are clipped by default (`grad_clip` in `config.json`) to stabilize training.
+- Discrete masking uses a cosine schedule for smoother corruption.
+- Continuous sampling is clipped in normalized space each step for stability.
+- Optional conditioning by file id (`train*.csv.gz`) is enabled by default for multi-file training.
+- Continuous head can be bounded with `tanh` via `use_tanh_eps` in config.
 - The script only samples the first 5000 rows to stay fast.
 - `prepare_data.py` runs without PyTorch, but `train.py` and `sample.py` require it.
 - `train.py` and `sample.py` auto-select GPU if available; otherwise they fall back to CPU.
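Note on the masking-schedule bullet: the change lands in `q_sample_discrete` further down in this commit. A minimal sketch of the old linear vs. new cosine corruption probability, assuming `T` timesteps as configured in `config.json`:

```python
import math

T = 200  # "timesteps" in config.json

def p_linear(t: int) -> float:
    # old behaviour: masking probability grows linearly in t
    return t / T

def p_cosine(t: int) -> float:
    # new behaviour: p(0) = 0, p(T) = 1, gentler near both ends
    return 0.5 * (1.0 - math.cos(math.pi * t / T))

print(p_linear(T // 4), p_cosine(T // 4))  # 0.25 vs ~0.146
```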
@@ -1,5 +1,6 @@
 {
   "data_path": "../../dataset/hai/hai-21.03/train1.csv.gz",
+  "data_glob": "../../dataset/hai/hai-21.03/train*.csv.gz",
   "split_path": "./feature_split.json",
   "stats_path": "./results/cont_stats.json",
   "vocab_path": "./results/disc_vocab.json",
@@ -14,5 +15,16 @@
   "lr": 0.0005,
   "seed": 1337,
   "log_every": 10,
-  "ckpt_every": 50
+  "ckpt_every": 50,
+  "ema_decay": 0.999,
+  "use_ema": true,
+  "clip_k": 5.0,
+  "grad_clip": 1.0,
+  "use_condition": true,
+  "condition_type": "file_id",
+  "cond_dim": 32,
+  "use_tanh_eps": true,
+  "eps_scale": 1.0,
+  "sample_batch_size": 8,
+  "sample_seq_len": 128
 }
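Note: the scripts in this commit read the new keys back with `cfg.get(...)` fallbacks, so an older config without them still works. A sketch of the lookup pattern, with key names and defaults as they appear in this diff:

```python
import json

with open("example/config.json", "r", encoding="utf-8") as f:
    cfg = json.load(f)

# sampling falls back to the training values when the sample_* keys are absent
seq_len = cfg.get("sample_seq_len", cfg.get("seq_len", 64))
batch_size = cfg.get("sample_batch_size", cfg.get("batch_size", 2))
clip_k = cfg.get("clip_k", 5.0)
use_condition = bool(cfg.get("use_condition")) and cfg.get("condition_type") == "file_id"
```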
@@ -4,7 +4,7 @@
 import csv
 import gzip
 import json
-from typing import Dict, Iterable, List, Optional, Tuple
+from typing import Dict, Iterable, List, Optional, Tuple, Union
 
 
 
@@ -13,15 +13,18 @@ def load_split(path: str) -> Dict[str, List[str]]:
     return json.load(f)
 
 
-def iter_rows(path: str) -> Iterable[Dict[str, str]]:
-    with gzip.open(path, "rt", newline="") as f:
-        reader = csv.DictReader(f)
-        for row in reader:
-            yield row
+def iter_rows(path_or_paths: Union[str, List[str]]) -> Iterable[Dict[str, str]]:
+    paths = [path_or_paths] if isinstance(path_or_paths, str) else list(path_or_paths)
+    for path in paths:
+        opener = gzip.open if str(path).endswith(".gz") else open
+        with opener(path, "rt", newline="") as f:
+            reader = csv.DictReader(f)
+            for row in reader:
+                yield row
 
 
 def compute_cont_stats(
-    path: str,
+    path: Union[str, List[str]],
     cont_cols: List[str],
     max_rows: Optional[int] = None,
 ) -> Tuple[Dict[str, float], Dict[str, float]]:
@@ -52,7 +55,7 @@ def compute_cont_stats(
 
 
 def build_vocab(
-    path: str,
+    path: Union[str, List[str]],
    disc_cols: List[str],
     max_rows: Optional[int] = None,
 ) -> Dict[str, Dict[str, int]]:
@@ -80,7 +83,7 @@ def normalize_cont(x, cont_cols: List[str], mean: Dict[str, float], std: Dict[st
 
 
 def windowed_batches(
-    path: str,
+    path: Union[str, List[str]],
     cont_cols: List[str],
     disc_cols: List[str],
     vocab: Dict[str, Dict[str, int]],
@@ -89,10 +92,12 @@ def windowed_batches(
     batch_size: int,
     seq_len: int,
     max_batches: Optional[int] = None,
+    return_file_id: bool = False,
 ):
     import torch
     batch_cont = []
     batch_disc = []
+    batch_file = []
     seq_cont = []
     seq_disc = []
 
@@ -105,22 +110,34 @@ def windowed_batches(
         seq_disc = []
 
     batches_yielded = 0
-    for row in iter_rows(path):
-        cont_row = [float(row[c]) for c in cont_cols]
-        disc_row = [vocab[c].get(row[c], vocab[c]["<UNK>"]) for c in disc_cols]
-        seq_cont.append(cont_row)
-        seq_disc.append(disc_row)
-        if len(seq_cont) == seq_len:
-            flush_seq()
-            if len(batch_cont) == batch_size:
-                x_cont = torch.tensor(batch_cont, dtype=torch.float32)
-                x_disc = torch.tensor(batch_disc, dtype=torch.long)
-                x_cont = normalize_cont(x_cont, cont_cols, mean, std)
-                yield x_cont, x_disc
-                batch_cont = []
-                batch_disc = []
-                batches_yielded += 1
-                if max_batches is not None and batches_yielded >= max_batches:
-                    return
+    paths = [path] if isinstance(path, str) else list(path)
+    for file_id, p in enumerate(paths):
+        for row in iter_rows(p):
+            cont_row = [float(row[c]) for c in cont_cols]
+            disc_row = [vocab[c].get(row[c], vocab[c]["<UNK>"]) for c in disc_cols]
+            seq_cont.append(cont_row)
+            seq_disc.append(disc_row)
+            if len(seq_cont) == seq_len:
+                flush_seq()
+                if return_file_id:
+                    batch_file.append(file_id)
+                if len(batch_cont) == batch_size:
+                    x_cont = torch.tensor(batch_cont, dtype=torch.float32)
+                    x_disc = torch.tensor(batch_disc, dtype=torch.long)
+                    x_cont = normalize_cont(x_cont, cont_cols, mean, std)
+                    if return_file_id:
+                        x_file = torch.tensor(batch_file, dtype=torch.long)
+                        yield x_cont, x_disc, x_file
+                    else:
+                        yield x_cont, x_disc
+                    batch_cont = []
+                    batch_disc = []
+                    batch_file = []
+                    batches_yielded += 1
+                    if max_batches is not None and batches_yielded >= max_batches:
+                        return
+        # drop partial sequence at file boundary
+        seq_cont = []
+        seq_disc = []
 
     # Drop last partial batch for simplicity
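Note: with `return_file_id=True`, `windowed_batches` yields 3-tuples and assigns `file_id` by position in the path list. A hedged usage sketch; the module name, the column names, and the two parameters elided between `vocab` and `batch_size` (presumably `mean` and `std`, given the `normalize_cont` call) are assumptions, not shown in this diff:

```python
# hypothetical module name and column names, for illustration only
from hai_data import compute_cont_stats, build_vocab, windowed_batches

paths = ["train1.csv.gz", "train2.csv.gz"]  # file_id 0 and 1, in list order
cont_cols = ["P1_B2004"]   # placeholder column name
disc_cols = ["P1_PP01AD"]  # placeholder column name

mean, std = compute_cont_stats(paths, cont_cols, max_rows=1000)
vocab = build_vocab(paths, disc_cols, max_rows=1000)
for x_cont, x_disc, x_file in windowed_batches(
    paths, cont_cols, disc_cols, vocab, mean, std,
    batch_size=2, seq_len=64, max_batches=4, return_file_id=True,
):
    print(x_cont.shape, x_disc.shape, x_file.tolist())  # x_file: per-sequence file ids
```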
@@ -54,6 +54,7 @@ def parse_args():
     base_dir = Path(__file__).resolve().parent
     repo_dir = base_dir.parent.parent
     parser.add_argument("--data-path", default=str(repo_dir / "dataset" / "hai" / "hai-21.03" / "train1.csv.gz"))
+    parser.add_argument("--data-glob", default=str(repo_dir / "dataset" / "hai" / "hai-21.03" / "train*.csv.gz"))
     parser.add_argument("--split-path", default=str(base_dir / "feature_split.json"))
     parser.add_argument("--stats-path", default=str(base_dir / "results" / "cont_stats.json"))
     parser.add_argument("--vocab-path", default=str(base_dir / "results" / "disc_vocab.json"))
@@ -64,6 +65,11 @@ def parse_args():
     parser.add_argument("--batch-size", type=int, default=2)
     parser.add_argument("--device", default="auto", help="cpu, cuda, or auto")
     parser.add_argument("--include-time", action="store_true", help="Include time column as a simple index")
+    parser.add_argument("--clip-k", type=float, default=5.0, help="Clip continuous values to mean±k*std")
+    parser.add_argument("--use-ema", action="store_true", help="Use EMA weights if available")
+    parser.add_argument("--config", default=None, help="Optional config_used.json to infer conditioning")
+    parser.add_argument("--condition-id", type=int, default=-1, help="Condition file id (0..N-1), -1=random")
+    parser.add_argument("--include-condition", action="store_true", help="Include condition id column in CSV")
     return parser.parse_args()
 
 
@@ -75,6 +81,15 @@ def main():
     if not os.path.exists(args.model_path):
         raise SystemExit("missing model file: %s" % args.model_path)
 
+    # resolve header source
+    data_path = args.data_path
+    if args.data_glob:
+        base = Path(args.data_glob).parent
+        pat = Path(args.data_glob).name
+        matches = sorted(base.glob(pat))
+        if matches:
+            data_path = str(matches[0])
+
     split = load_split(args.split_path)
     time_col = split.get("time_column", "time")
     cont_cols = [c for c in split["continuous"] if c != time_col]
@@ -89,8 +104,31 @@ def main():
     vocab_sizes = [len(vocab[c]) for c in disc_cols]
 
     device = resolve_device(args.device)
-    model = HybridDiffusionModel(cont_dim=len(cont_cols), disc_vocab_sizes=vocab_sizes).to(device)
-    model.load_state_dict(torch.load(args.model_path, map_location=device, weights_only=True))
+    cfg = {}
+    use_condition = False
+    cond_vocab_size = 0
+    if args.config and os.path.exists(args.config):
+        with open(args.config, "r", encoding="utf-8") as f:
+            cfg = json.load(f)
+        use_condition = bool(cfg.get("use_condition")) and cfg.get("condition_type") == "file_id"
+        if use_condition:
+            base = Path(cfg.get("data_glob", args.data_glob)).parent
+            pat = Path(cfg.get("data_glob", args.data_glob)).name
+            cond_vocab_size = len(sorted(base.glob(pat)))
+
+    model = HybridDiffusionModel(
+        cont_dim=len(cont_cols),
+        disc_vocab_sizes=vocab_sizes,
+        cond_vocab_size=cond_vocab_size if use_condition else 0,
+        cond_dim=int(cfg.get("cond_dim", 32)),
+        use_tanh_eps=bool(cfg.get("use_tanh_eps", False)),
+        eps_scale=float(cfg.get("eps_scale", 1.0)),
+    ).to(device)
+    if args.use_ema and os.path.exists(args.model_path.replace("model.pt", "model_ema.pt")):
+        ema_path = args.model_path.replace("model.pt", "model_ema.pt")
+        model.load_state_dict(torch.load(ema_path, map_location=device, weights_only=True))
+    else:
+        model.load_state_dict(torch.load(args.model_path, map_location=device, weights_only=True))
     model.eval()
 
     betas = cosine_beta_schedule(args.timesteps).to(device)
@@ -108,9 +146,20 @@ def main():
     for i in range(len(disc_cols)):
         x_disc[:, :, i] = mask_tokens[i]
 
+    # condition id
+    cond = None
+    if use_condition:
+        if cond_vocab_size <= 0:
+            raise SystemExit("use_condition enabled but no files matched data_glob")
+        if args.condition_id < 0:
+            cond_id = torch.randint(0, cond_vocab_size, (args.batch_size,), device=device)
+        else:
+            cond_id = torch.full((args.batch_size,), int(args.condition_id), device=device, dtype=torch.long)
+        cond = cond_id
+
     for t in reversed(range(args.timesteps)):
         t_batch = torch.full((args.batch_size,), t, device=device, dtype=torch.long)
-        eps_pred, logits = model(x_cont, x_disc, t_batch)
+        eps_pred, logits = model(x_cont, x_disc, t_batch, cond)
 
         a_t = alphas[t]
         a_bar_t = alphas_cumprod[t]
@@ -122,6 +171,8 @@ def main():
             x_cont = mean_x + torch.sqrt(betas[t]) * noise
         else:
             x_cont = mean_x
+        if args.clip_k > 0:
+            x_cont = torch.clamp(x_cont, -args.clip_k, args.clip_k)
 
         for i, logit in enumerate(logits):
             if t == 0:
@@ -136,15 +187,22 @@ def main():
             )
             x_disc[:, :, i][mask] = sampled[mask]
 
+    # move to CPU for export
     x_cont = x_cont.cpu()
     x_disc = x_disc.cpu()
 
+    # clip in normalized space to avoid extreme blow-up
+    if args.clip_k > 0:
+        x_cont = torch.clamp(x_cont, -args.clip_k, args.clip_k)
+
     mean_vec = torch.tensor([mean[c] for c in cont_cols], dtype=x_cont.dtype)
     std_vec = torch.tensor([std[c] for c in cont_cols], dtype=x_cont.dtype)
     x_cont = x_cont * std_vec + mean_vec
 
-    header = read_header(args.data_path)
+    header = read_header(data_path)
     out_cols = [c for c in header if c != time_col or args.include_time]
+    if args.include_condition and use_condition:
+        out_cols = ["__cond_file_id"] + out_cols
 
     os.makedirs(os.path.dirname(args.out), exist_ok=True)
     with open(args.out, "w", newline="", encoding="utf-8") as f:
@@ -155,6 +213,8 @@ def main():
         for b in range(args.batch_size):
             for t in range(args.seq_len):
                 row = {}
+                if args.include_condition and use_condition:
+                    row["__cond_file_id"] = str(int(cond[b].item())) if cond is not None else "-1"
                 if args.include_time and time_col in header:
                     row[time_col] = str(row_index)
                 for i, c in enumerate(cont_cols):
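Note: the per-step clamp added above runs right after the reverse update. The `mean_x` line itself sits outside these hunks; a sketch of one reverse step, assuming the standard DDPM posterior mean:

```python
import torch

def reverse_step(x_t, eps_pred, t, betas, alphas, alphas_cumprod, clip_k=5.0):
    # standard DDPM mean: (x_t - beta_t / sqrt(1 - a_bar_t) * eps) / sqrt(a_t)
    a_t, a_bar_t, beta_t = alphas[t], alphas_cumprod[t], betas[t]
    mean_x = (x_t - beta_t / torch.sqrt(1.0 - a_bar_t) * eps_pred) / torch.sqrt(a_t)
    if t > 0:
        x_prev = mean_x + torch.sqrt(beta_t) * torch.randn_like(x_t)
    else:
        x_prev = mean_x
    if clip_k > 0:  # per-step clipping in normalized space, as added in this commit
        x_prev = torch.clamp(x_prev, -clip_k, clip_k)
    return x_prev
```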
@@ -36,9 +36,10 @@ def q_sample_discrete(
     mask_tokens: torch.Tensor,
     max_t: int,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-    """Randomly mask discrete tokens with a linear schedule over t."""
+    """Randomly mask discrete tokens with a cosine schedule over t."""
     bsz = x0.size(0)
-    p = t.float() / float(max_t)
+    # cosine schedule: p(0)=0, p(max_t)=1
+    p = 0.5 * (1.0 - torch.cos(math.pi * t.float() / float(max_t)))
     p = p.view(bsz, 1, 1)
     mask = torch.rand_like(x0.float()) < p
     x_masked = x0.clone()
@@ -69,12 +70,24 @@ class HybridDiffusionModel(nn.Module):
         disc_vocab_sizes: List[int],
         time_dim: int = 64,
         hidden_dim: int = 256,
+        cond_vocab_size: int = 0,
+        cond_dim: int = 32,
+        use_tanh_eps: bool = False,
+        eps_scale: float = 1.0,
     ):
         super().__init__()
         self.cont_dim = cont_dim
         self.disc_vocab_sizes = disc_vocab_sizes
 
         self.time_embed = SinusoidalTimeEmbedding(time_dim)
+        self.use_tanh_eps = use_tanh_eps
+        self.eps_scale = eps_scale
+
+        self.cond_vocab_size = cond_vocab_size
+        self.cond_dim = cond_dim
+        self.cond_embed = None
+        if cond_vocab_size and cond_vocab_size > 0:
+            self.cond_embed = nn.Embedding(cond_vocab_size, cond_dim)
 
         self.disc_embeds = nn.ModuleList([
             nn.Embedding(vocab_size + 1, min(32, vocab_size * 2))
@@ -83,7 +96,8 @@ class HybridDiffusionModel(nn.Module):
         disc_embed_dim = sum(e.embedding_dim for e in self.disc_embeds)
 
         self.cont_proj = nn.Linear(cont_dim, cont_dim)
-        self.in_proj = nn.Linear(cont_dim + disc_embed_dim + time_dim, hidden_dim)
+        in_dim = cont_dim + disc_embed_dim + time_dim + (cond_dim if self.cond_embed is not None else 0)
+        self.in_proj = nn.Linear(in_dim, hidden_dim)
         self.backbone = nn.GRU(hidden_dim, hidden_dim, batch_first=True)
 
         self.cont_head = nn.Linear(hidden_dim, cont_dim)
@@ -92,7 +106,7 @@ class HybridDiffusionModel(nn.Module):
             for vocab_size in disc_vocab_sizes
         ])
 
-    def forward(self, x_cont: torch.Tensor, x_disc: torch.Tensor, t: torch.Tensor):
+    def forward(self, x_cont: torch.Tensor, x_disc: torch.Tensor, t: torch.Tensor, cond: torch.Tensor = None):
         """x_cont: (B,T,Cc), x_disc: (B,T,Cd) with integer tokens."""
         time_emb = self.time_embed(t)
         time_emb = time_emb.unsqueeze(1).expand(-1, x_cont.size(1), -1)
@@ -102,12 +116,23 @@ class HybridDiffusionModel(nn.Module):
             disc_embs.append(emb(x_disc[:, :, i]))
         disc_feat = torch.cat(disc_embs, dim=-1)
 
+        cond_feat = None
+        if self.cond_embed is not None:
+            if cond is None:
+                raise ValueError("cond is required when cond_vocab_size > 0")
+            cond_feat = self.cond_embed(cond).unsqueeze(1).expand(-1, x_cont.size(1), -1)
+
         cont_feat = self.cont_proj(x_cont)
-        feat = torch.cat([cont_feat, disc_feat, time_emb], dim=-1)
+        parts = [cont_feat, disc_feat, time_emb]
+        if cond_feat is not None:
+            parts.append(cond_feat)
+        feat = torch.cat(parts, dim=-1)
         feat = self.in_proj(feat)
 
         out, _ = self.backbone(feat)
 
         eps_pred = self.cont_head(out)
+        if self.use_tanh_eps:
+            eps_pred = torch.tanh(eps_pred) * self.eps_scale
         logits = [head(out) for head in self.disc_heads]
         return eps_pred, logits
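Note: `use_tanh_eps` soft-bounds the noise prediction to (-eps_scale, eps_scale). Since tanh also compresses mid-range values, not just the tails, this is a soft bound rather than a hard clip. A quick illustration:

```python
import torch

eps_scale = 1.0
raw = torch.tensor([-4.0, -1.0, 0.0, 1.0, 4.0])
print(torch.tanh(raw) * eps_scale)
# tensor([-0.9993, -0.7616,  0.0000,  0.7616,  0.9993])
```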
@@ -10,7 +10,7 @@ from platform_utils import safe_path, ensure_dir
 
 BASE_DIR = Path(__file__).resolve().parent
 REPO_DIR = BASE_DIR.parent.parent
-DATA_PATH = REPO_DIR / "dataset" / "hai" / "hai-21.03" / "train1.csv.gz"
+DATA_GLOB = REPO_DIR / "dataset" / "hai" / "hai-21.03" / "train*.csv.gz"
 SPLIT_PATH = BASE_DIR / "feature_split.json"
 OUT_STATS = BASE_DIR / "results" / "cont_stats.json"
 OUT_VOCAB = BASE_DIR / "results" / "disc_vocab.json"
@@ -22,8 +22,13 @@ def main(max_rows: Optional[int] = None):
     cont_cols = [c for c in split["continuous"] if c != time_col]
     disc_cols = [c for c in split["discrete"] if not c.startswith("attack") and c != time_col]
 
-    mean, std = compute_cont_stats(safe_path(DATA_PATH), cont_cols, max_rows=max_rows)
-    vocab = build_vocab(safe_path(DATA_PATH), disc_cols, max_rows=max_rows)
+    data_paths = sorted(Path(REPO_DIR / "dataset" / "hai" / "hai-21.03").glob("train*.csv.gz"))
+    if not data_paths:
+        raise SystemExit("no train files found under %s" % str(DATA_GLOB))
+    data_paths = [safe_path(p) for p in data_paths]
+
+    mean, std = compute_cont_stats(data_paths, cont_cols, max_rows=max_rows)
+    vocab = build_vocab(data_paths, disc_cols, max_rows=max_rows)
 
     ensure_dir(OUT_STATS.parent)
     with open(safe_path(OUT_STATS), "w", encoding="utf-8") as f:
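Note: every script in this commit derives file ids from sorted glob order, so conditioning stays consistent across prepare/train/sample as long as the file set does not change. A quick check of the mapping, with the dataset layout assumed from the paths in this diff:

```python
from pathlib import Path

base = Path("../../dataset/hai/hai-21.03")
for file_id, p in enumerate(sorted(base.glob("train*.csv.gz"))):
    print(file_id, p.name)  # e.g. 0 train1.csv.gz, 1 train2.csv.gz, ...
```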
@@ -5,6 +5,7 @@ import argparse
 import subprocess
 import sys
 from pathlib import Path
+import json
 
 from platform_utils import safe_path, is_windows
 
@@ -40,6 +41,13 @@ def parse_args():
 def main():
     args = parse_args()
     base_dir = Path(__file__).resolve().parent
+    config_path = Path(args.config)
+    with open(config_path, "r", encoding="utf-8") as f:
+        cfg = json.load(f)
+    timesteps = cfg.get("timesteps", 200)
+    seq_len = cfg.get("sample_seq_len", cfg.get("seq_len", 64))
+    batch_size = cfg.get("sample_batch_size", cfg.get("batch_size", 2))
+    clip_k = cfg.get("clip_k", 5.0)
     run([sys.executable, str(base_dir / "prepare_data.py")])
     run([sys.executable, str(base_dir / "train.py"), "--config", args.config, "--device", args.device])
     run(
@@ -49,6 +57,17 @@ def main():
             "--include-time",
             "--device",
             args.device,
+            "--config",
+            str(config_path),
+            "--timesteps",
+            str(timesteps),
+            "--seq-len",
+            str(seq_len),
+            "--batch-size",
+            str(batch_size),
+            "--clip-k",
+            str(clip_k),
+            "--use-ema",
         ]
     )
    run([sys.executable, str(base_dir / "evaluate_generated.py")])
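Note: run_pipeline.py forwards the sampling hyperparameters from the config to sample.py as command-line flags; every argv element must already be a string, hence the `str(...)` conversions above. A minimal sketch of the pattern (script name and values illustrative):

```python
import subprocess
import sys

argv = [sys.executable, "sample.py", "--timesteps", str(200), "--clip-k", str(5.0)]
subprocess.run(argv, check=True)  # numeric values must be converted to str first
```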
@@ -17,6 +17,7 @@ BASE_DIR = Path(__file__).resolve().parent
 SPLIT_PATH = BASE_DIR / "feature_split.json"
 VOCAB_PATH = BASE_DIR / "results" / "disc_vocab.json"
 MODEL_PATH = BASE_DIR / "results" / "model.pt"
+CONFIG_PATH = BASE_DIR / "results" / "config_used.json"
 
 # Use the resolve_device function from platform_utils
 
@@ -25,6 +26,7 @@ DEVICE = resolve_device("auto")
 TIMESTEPS = 200
 SEQ_LEN = 64
 BATCH_SIZE = 2
+CLIP_K = 5.0
 
 
 def load_vocab():
@@ -33,6 +35,19 @@ def load_vocab():
 
 
 def main():
+    cfg = {}
+    if CONFIG_PATH.exists():
+        with open(str(CONFIG_PATH), "r", encoding="utf-8") as f:
+            cfg = json.load(f)
+    timesteps = int(cfg.get("timesteps", TIMESTEPS))
+    seq_len = int(cfg.get("sample_seq_len", cfg.get("seq_len", SEQ_LEN)))
+    batch_size = int(cfg.get("sample_batch_size", cfg.get("batch_size", BATCH_SIZE)))
+    clip_k = float(cfg.get("clip_k", CLIP_K))
+    use_condition = bool(cfg.get("use_condition")) and cfg.get("condition_type") == "file_id"
+    cond_dim = int(cfg.get("cond_dim", 32))
+    use_tanh_eps = bool(cfg.get("use_tanh_eps", False))
+    eps_scale = float(cfg.get("eps_scale", 1.0))
+
     split = load_split(str(SPLIT_PATH))
     time_col = split.get("time_column", "time")
     cont_cols = [c for c in split["continuous"] if c != time_col]
@@ -42,26 +57,46 @@ def main():
     vocab_sizes = [len(vocab[c]) for c in disc_cols]
 
     print("device", DEVICE)
-    model = HybridDiffusionModel(cont_dim=len(cont_cols), disc_vocab_sizes=vocab_sizes).to(DEVICE)
+    cond_vocab_size = 0
+    if use_condition:
+        data_glob = cfg.get("data_glob")
+        if data_glob:
+            base = Path(data_glob).parent
+            pat = Path(data_glob).name
+            cond_vocab_size = len(sorted(base.glob(pat)))
+    model = HybridDiffusionModel(
+        cont_dim=len(cont_cols),
+        disc_vocab_sizes=vocab_sizes,
+        cond_vocab_size=cond_vocab_size,
+        cond_dim=cond_dim,
+        use_tanh_eps=use_tanh_eps,
+        eps_scale=eps_scale,
+    ).to(DEVICE)
     if MODEL_PATH.exists():
         model.load_state_dict(torch.load(str(MODEL_PATH), map_location=DEVICE, weights_only=True))
     model.eval()
 
-    betas = cosine_beta_schedule(TIMESTEPS).to(DEVICE)
+    betas = cosine_beta_schedule(timesteps).to(DEVICE)
     alphas = 1.0 - betas
     alphas_cumprod = torch.cumprod(alphas, dim=0)
 
-    x_cont = torch.randn(BATCH_SIZE, SEQ_LEN, len(cont_cols), device=DEVICE)
-    x_disc = torch.full((BATCH_SIZE, SEQ_LEN, len(disc_cols)), 0, device=DEVICE, dtype=torch.long)
+    x_cont = torch.randn(batch_size, seq_len, len(cont_cols), device=DEVICE)
+    x_disc = torch.full((batch_size, seq_len, len(disc_cols)), 0, device=DEVICE, dtype=torch.long)
     mask_tokens = torch.tensor(vocab_sizes, device=DEVICE)
 
     # Initialize discrete with mask tokens
     for i in range(len(disc_cols)):
         x_disc[:, :, i] = mask_tokens[i]
 
-    for t in reversed(range(TIMESTEPS)):
-        t_batch = torch.full((BATCH_SIZE,), t, device=DEVICE, dtype=torch.long)
-        eps_pred, logits = model(x_cont, x_disc, t_batch)
+    cond = None
+    if use_condition:
+        if cond_vocab_size <= 0:
+            raise SystemExit("use_condition enabled but no files matched data_glob")
+        cond = torch.randint(0, cond_vocab_size, (batch_size,), device=DEVICE, dtype=torch.long)
+
+    for t in reversed(range(timesteps)):
+        t_batch = torch.full((batch_size,), t, device=DEVICE, dtype=torch.long)
+        eps_pred, logits = model(x_cont, x_disc, t_batch, cond)
 
         # Continuous reverse step (DDPM): x_{t-1} mean
         a_t = alphas[t]
@@ -74,6 +109,8 @@ def main():
             x_cont = mean + torch.sqrt(betas[t]) * noise
         else:
             x_cont = mean
+        if clip_k > 0:
+            x_cont = torch.clamp(x_cont, -clip_k, clip_k)
 
         # Discrete: fill masked positions by sampling logits
         for i, logit in enumerate(logits):
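Note: each discrete embedding table is built with `vocab_size + 1` rows; the extra index `vocab_size` serves as the mask token, and sampling starts from fully masked sequences. A small sketch of that initialization:

```python
import torch

vocab_sizes = [3, 5]                     # per-column vocab sizes (illustrative)
mask_tokens = torch.tensor(vocab_sizes)  # mask id == vocab_size for each column
x_disc = torch.zeros(2, 4, len(vocab_sizes), dtype=torch.long)
for i in range(len(vocab_sizes)):
    x_disc[:, :, i] = mask_tokens[i]     # start fully masked (ids 3 and 5)
```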
@@ -26,6 +26,7 @@ REPO_DIR = BASE_DIR.parent.parent
 
 DEFAULTS = {
     "data_path": REPO_DIR / "dataset" / "hai" / "hai-21.03" / "train1.csv.gz",
+    "data_glob": REPO_DIR / "dataset" / "hai" / "hai-21.03" / "train*.csv.gz",
     "split_path": BASE_DIR / "feature_split.json",
     "stats_path": BASE_DIR / "results" / "cont_stats.json",
     "vocab_path": BASE_DIR / "results" / "disc_vocab.json",
@@ -41,6 +42,15 @@ DEFAULTS = {
     "seed": 1337,
     "log_every": 10,
     "ckpt_every": 50,
+    "ema_decay": 0.999,
+    "use_ema": True,
+    "clip_k": 5.0,
+    "grad_clip": 1.0,
+    "use_condition": True,
+    "condition_type": "file_id",
+    "cond_dim": 32,
+    "use_tanh_eps": True,
+    "eps_scale": 1.0,
 }
 
 
@@ -69,7 +79,7 @@ def parse_args():
 
 
 def resolve_config_paths(config, base_dir: Path):
-    keys = ["data_path", "split_path", "stats_path", "vocab_path", "out_dir"]
+    keys = ["data_path", "data_glob", "split_path", "stats_path", "vocab_path", "out_dir"]
     for key in keys:
         if key in config:
             # If the value is a string, convert it to a Path object
@@ -85,6 +95,26 @@ def resolve_config_paths(config, base_dir: Path):
     return config
 
 
+class EMA:
+    def __init__(self, model, decay: float):
+        self.decay = decay
+        self.shadow = {}
+        for name, param in model.named_parameters():
+            if param.requires_grad:
+                self.shadow[name] = param.detach().clone()
+
+    def update(self, model):
+        with torch.no_grad():
+            for name, param in model.named_parameters():
+                if not param.requires_grad:
+                    continue
+                old = self.shadow[name]
+                self.shadow[name] = old * self.decay + param.detach() * (1.0 - self.decay)
+
+    def state_dict(self):
+        return self.shadow
+
+
 def main():
     args = parse_args()
     config = dict(DEFAULTS)
@@ -113,25 +143,47 @@ def main():
     vocab = load_json(config["vocab_path"])["vocab"]
     vocab_sizes = [len(vocab[c]) for c in disc_cols]
 
+    data_paths = None
+    if "data_glob" in config and config["data_glob"]:
+        data_paths = sorted(Path(config["data_glob"]).parent.glob(Path(config["data_glob"]).name))
+        if data_paths:
+            data_paths = [safe_path(p) for p in data_paths]
+    if not data_paths:
+        data_paths = [safe_path(config["data_path"])]
+
+    use_condition = bool(config.get("use_condition")) and config.get("condition_type") == "file_id"
+    cond_vocab_size = len(data_paths) if use_condition else 0
+
     device = resolve_device(str(config["device"]))
     print("device", device)
-    model = HybridDiffusionModel(cont_dim=len(cont_cols), disc_vocab_sizes=vocab_sizes).to(device)
+    model = HybridDiffusionModel(
+        cont_dim=len(cont_cols),
+        disc_vocab_sizes=vocab_sizes,
+        cond_vocab_size=cond_vocab_size,
+        cond_dim=int(config.get("cond_dim", 32)),
+        use_tanh_eps=bool(config.get("use_tanh_eps", False)),
+        eps_scale=float(config.get("eps_scale", 1.0)),
+    ).to(device)
     opt = torch.optim.Adam(model.parameters(), lr=float(config["lr"]))
+    ema = EMA(model, float(config["ema_decay"])) if config.get("use_ema") else None
 
     betas = cosine_beta_schedule(int(config["timesteps"])).to(device)
     alphas = 1.0 - betas
     alphas_cumprod = torch.cumprod(alphas, dim=0)
 
     os.makedirs(config["out_dir"], exist_ok=True)
-    log_path = os.path.join(config["out_dir"], "train_log.csv")
+    out_dir = safe_path(config["out_dir"])
+    log_path = os.path.join(out_dir, "train_log.csv")
     with open(log_path, "w", encoding="utf-8") as f:
         f.write("epoch,step,loss,loss_cont,loss_disc\n")
+    with open(os.path.join(out_dir, "config_used.json"), "w", encoding="utf-8") as f:
+        json.dump(config, f, indent=2)
 
     total_step = 0
     for epoch in range(int(config["epochs"])):
-        for step, (x_cont, x_disc) in enumerate(
+        for step, batch in enumerate(
             windowed_batches(
-                config["data_path"],
+                data_paths,
                 cont_cols,
                 disc_cols,
                 vocab,
@@ -140,8 +192,15 @@ def main():
                 batch_size=int(config["batch_size"]),
                 seq_len=int(config["seq_len"]),
                 max_batches=int(config["max_batches"]),
+                return_file_id=use_condition,
             )
         ):
+            if use_condition:
+                x_cont, x_disc, cond = batch
+                cond = cond.to(device)
+            else:
+                x_cont, x_disc = batch
+                cond = None
             x_cont = x_cont.to(device)
             x_disc = x_disc.to(device)
 
@@ -153,21 +212,29 @@ def main():
             mask_tokens = torch.tensor(vocab_sizes, device=device)
             x_disc_t, mask = q_sample_discrete(x_disc, t, mask_tokens, int(config["timesteps"]))
 
-            eps_pred, logits = model(x_cont_t, x_disc_t, t)
+            eps_pred, logits = model(x_cont_t, x_disc_t, t, cond)
 
             loss_cont = F.mse_loss(eps_pred, noise)
             loss_disc = 0.0
+            loss_disc_count = 0
             for i, logit in enumerate(logits):
                 if mask[:, :, i].any():
                     loss_disc = loss_disc + F.cross_entropy(
                         logit[mask[:, :, i]], x_disc[:, :, i][mask[:, :, i]]
                     )
+                    loss_disc_count += 1
+            if loss_disc_count > 0:
+                loss_disc = loss_disc / loss_disc_count
 
             lam = float(config["lambda"])
             loss = lam * loss_cont + (1 - lam) * loss_disc
             opt.zero_grad()
             loss.backward()
+            if float(config.get("grad_clip", 0.0)) > 0:
+                torch.nn.utils.clip_grad_norm_(model.parameters(), float(config["grad_clip"]))
             opt.step()
+            if ema is not None:
+                ema.update(model)
 
             if step % int(config["log_every"]) == 0:
                 print("epoch", epoch, "step", step, "loss", float(loss))
@@ -185,9 +252,13 @@ def main():
             "config": config,
             "step": total_step,
         }
-        torch.save(ckpt, os.path.join(config["out_dir"], "model_ckpt.pt"))
+        if ema is not None:
+            ckpt["ema"] = ema.state_dict()
+        torch.save(ckpt, os.path.join(out_dir, "model_ckpt.pt"))
 
-    torch.save(model.state_dict(), os.path.join(config["out_dir"], "model.pt"))
+    torch.save(model.state_dict(), os.path.join(out_dir, "model.pt"))
+    if ema is not None:
+        torch.save(ema.state_dict(), os.path.join(out_dir, "model_ema.pt"))
 
 
 if __name__ == "__main__":
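Note: the EMA shadow holds only `requires_grad` parameters keyed by name, which is why sample.py can pass `model_ema.pt` straight to `load_state_dict` (assuming the model registers no buffers; a parameters-only state dict would otherwise miss them). A sketch of the update rule with the default decay:

```python
import torch

decay = 0.999  # "ema_decay" in config.json
param = torch.tensor([1.0])
shadow = param.clone()
for _ in range(3):
    param = param + 0.1  # stand-in for an optimizer step
    shadow = shadow * decay + param * (1.0 - decay)
print(shadow)  # drifts toward param very slowly: heavy smoothing at 0.999
```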
requirements.txt (new file, 3 additions)
@@ -0,0 +1,3 @@
+torch
+numpy
+matplotlib