This commit is contained in:
2026-01-22 20:42:10 +08:00
parent f37a8ce179
commit 382c756dfe
10 changed files with 310 additions and 55 deletions

View File

@@ -54,6 +54,7 @@ def parse_args():
base_dir = Path(__file__).resolve().parent
repo_dir = base_dir.parent.parent
parser.add_argument("--data-path", default=str(repo_dir / "dataset" / "hai" / "hai-21.03" / "train1.csv.gz"))
parser.add_argument("--data-glob", default=str(repo_dir / "dataset" / "hai" / "hai-21.03" / "train*.csv.gz"))
parser.add_argument("--split-path", default=str(base_dir / "feature_split.json"))
parser.add_argument("--stats-path", default=str(base_dir / "results" / "cont_stats.json"))
parser.add_argument("--vocab-path", default=str(base_dir / "results" / "disc_vocab.json"))
@@ -64,6 +65,11 @@ def parse_args():
parser.add_argument("--batch-size", type=int, default=2)
parser.add_argument("--device", default="auto", help="cpu, cuda, or auto")
parser.add_argument("--include-time", action="store_true", help="Include time column as a simple index")
parser.add_argument("--clip-k", type=float, default=5.0, help="Clip continuous values to mean±k*std")
parser.add_argument("--use-ema", action="store_true", help="Use EMA weights if available")
parser.add_argument("--config", default=None, help="Optional config_used.json to infer conditioning")
parser.add_argument("--condition-id", type=int, default=-1, help="Condition file id (0..N-1), -1=random")
parser.add_argument("--include-condition", action="store_true", help="Include condition id column in CSV")
return parser.parse_args()
@@ -75,6 +81,15 @@ def main():
if not os.path.exists(args.model_path):
raise SystemExit("missing model file: %s" % args.model_path)
# resolve header source
data_path = args.data_path
if args.data_glob:
base = Path(args.data_glob).parent
pat = Path(args.data_glob).name
matches = sorted(base.glob(pat))
if matches:
data_path = str(matches[0])
split = load_split(args.split_path)
time_col = split.get("time_column", "time")
cont_cols = [c for c in split["continuous"] if c != time_col]
@@ -89,8 +104,31 @@ def main():
vocab_sizes = [len(vocab[c]) for c in disc_cols]
device = resolve_device(args.device)
model = HybridDiffusionModel(cont_dim=len(cont_cols), disc_vocab_sizes=vocab_sizes).to(device)
model.load_state_dict(torch.load(args.model_path, map_location=device, weights_only=True))
cfg = {}
use_condition = False
cond_vocab_size = 0
if args.config and os.path.exists(args.config):
with open(args.config, "r", encoding="utf-8") as f:
cfg = json.load(f)
use_condition = bool(cfg.get("use_condition")) and cfg.get("condition_type") == "file_id"
if use_condition:
base = Path(cfg.get("data_glob", args.data_glob)).parent
pat = Path(cfg.get("data_glob", args.data_glob)).name
cond_vocab_size = len(sorted(base.glob(pat)))
model = HybridDiffusionModel(
cont_dim=len(cont_cols),
disc_vocab_sizes=vocab_sizes,
cond_vocab_size=cond_vocab_size if use_condition else 0,
cond_dim=int(cfg.get("cond_dim", 32)),
use_tanh_eps=bool(cfg.get("use_tanh_eps", False)),
eps_scale=float(cfg.get("eps_scale", 1.0)),
).to(device)
if args.use_ema and os.path.exists(args.model_path.replace("model.pt", "model_ema.pt")):
ema_path = args.model_path.replace("model.pt", "model_ema.pt")
model.load_state_dict(torch.load(ema_path, map_location=device, weights_only=True))
else:
model.load_state_dict(torch.load(args.model_path, map_location=device, weights_only=True))
model.eval()
betas = cosine_beta_schedule(args.timesteps).to(device)
@@ -108,9 +146,20 @@ def main():
for i in range(len(disc_cols)):
x_disc[:, :, i] = mask_tokens[i]
# condition id
cond = None
if use_condition:
if cond_vocab_size <= 0:
raise SystemExit("use_condition enabled but no files matched data_glob")
if args.condition_id < 0:
cond_id = torch.randint(0, cond_vocab_size, (args.batch_size,), device=device)
else:
cond_id = torch.full((args.batch_size,), int(args.condition_id), device=device, dtype=torch.long)
cond = cond_id
for t in reversed(range(args.timesteps)):
t_batch = torch.full((args.batch_size,), t, device=device, dtype=torch.long)
eps_pred, logits = model(x_cont, x_disc, t_batch)
eps_pred, logits = model(x_cont, x_disc, t_batch, cond)
a_t = alphas[t]
a_bar_t = alphas_cumprod[t]
@@ -122,6 +171,8 @@ def main():
x_cont = mean_x + torch.sqrt(betas[t]) * noise
else:
x_cont = mean_x
if args.clip_k > 0:
x_cont = torch.clamp(x_cont, -args.clip_k, args.clip_k)
for i, logit in enumerate(logits):
if t == 0:
@@ -136,15 +187,22 @@ def main():
)
x_disc[:, :, i][mask] = sampled[mask]
# move to CPU for export
x_cont = x_cont.cpu()
x_disc = x_disc.cpu()
# clip in normalized space to avoid extreme blow-up
if args.clip_k > 0:
x_cont = torch.clamp(x_cont, -args.clip_k, args.clip_k)
mean_vec = torch.tensor([mean[c] for c in cont_cols], dtype=x_cont.dtype)
std_vec = torch.tensor([std[c] for c in cont_cols], dtype=x_cont.dtype)
x_cont = x_cont * std_vec + mean_vec
header = read_header(args.data_path)
header = read_header(data_path)
out_cols = [c for c in header if c != time_col or args.include_time]
if args.include_condition and use_condition:
out_cols = ["__cond_file_id"] + out_cols
os.makedirs(os.path.dirname(args.out), exist_ok=True)
with open(args.out, "w", newline="", encoding="utf-8") as f:
@@ -155,6 +213,8 @@ def main():
for b in range(args.batch_size):
for t in range(args.seq_len):
row = {}
if args.include_condition and use_condition:
row["__cond_file_id"] = str(int(cond[b].item())) if cond is not None else "-1"
if args.include_time and time_col in header:
row[time_col] = str(row_index)
for i, c in enumerate(cont_cols):