update
@@ -17,6 +17,7 @@ BASE_DIR = Path(__file__).resolve().parent
 SPLIT_PATH = BASE_DIR / "feature_split.json"
 VOCAB_PATH = BASE_DIR / "results" / "disc_vocab.json"
 MODEL_PATH = BASE_DIR / "results" / "model.pt"
+CONFIG_PATH = BASE_DIR / "results" / "config_used.json"
 
 # Use the resolve_device function from platform_utils
 
@@ -25,6 +26,7 @@ DEVICE = resolve_device("auto")
 TIMESTEPS = 200
 SEQ_LEN = 64
 BATCH_SIZE = 2
+CLIP_K = 5.0
 
 
 def load_vocab():
@@ -33,6 +35,19 @@ def load_vocab():
 
 
 def main():
+    cfg = {}
+    if CONFIG_PATH.exists():
+        with open(str(CONFIG_PATH), "r", encoding="utf-8") as f:
+            cfg = json.load(f)
+    timesteps = int(cfg.get("timesteps", TIMESTEPS))
+    seq_len = int(cfg.get("sample_seq_len", cfg.get("seq_len", SEQ_LEN)))
+    batch_size = int(cfg.get("sample_batch_size", cfg.get("batch_size", BATCH_SIZE)))
+    clip_k = float(cfg.get("clip_k", CLIP_K))
+    use_condition = bool(cfg.get("use_condition")) and cfg.get("condition_type") == "file_id"
+    cond_dim = int(cfg.get("cond_dim", 32))
+    use_tanh_eps = bool(cfg.get("use_tanh_eps", False))
+    eps_scale = float(cfg.get("eps_scale", 1.0))
+
     split = load_split(str(SPLIT_PATH))
     time_col = split.get("time_column", "time")
     cont_cols = [c for c in split["continuous"] if c != time_col]
@@ -42,26 +57,46 @@ def main():
     vocab_sizes = [len(vocab[c]) for c in disc_cols]
 
     print("device", DEVICE)
-    model = HybridDiffusionModel(cont_dim=len(cont_cols), disc_vocab_sizes=vocab_sizes).to(DEVICE)
+    cond_vocab_size = 0
+    if use_condition:
+        data_glob = cfg.get("data_glob")
+        if data_glob:
+            base = Path(data_glob).parent
+            pat = Path(data_glob).name
+            cond_vocab_size = len(sorted(base.glob(pat)))
+    model = HybridDiffusionModel(
+        cont_dim=len(cont_cols),
+        disc_vocab_sizes=vocab_sizes,
+        cond_vocab_size=cond_vocab_size,
+        cond_dim=cond_dim,
+        use_tanh_eps=use_tanh_eps,
+        eps_scale=eps_scale,
+    ).to(DEVICE)
     if MODEL_PATH.exists():
         model.load_state_dict(torch.load(str(MODEL_PATH), map_location=DEVICE, weights_only=True))
     model.eval()
 
-    betas = cosine_beta_schedule(TIMESTEPS).to(DEVICE)
+    betas = cosine_beta_schedule(timesteps).to(DEVICE)
     alphas = 1.0 - betas
     alphas_cumprod = torch.cumprod(alphas, dim=0)
 
-    x_cont = torch.randn(BATCH_SIZE, SEQ_LEN, len(cont_cols), device=DEVICE)
-    x_disc = torch.full((BATCH_SIZE, SEQ_LEN, len(disc_cols)), 0, device=DEVICE, dtype=torch.long)
+    x_cont = torch.randn(batch_size, seq_len, len(cont_cols), device=DEVICE)
+    x_disc = torch.full((batch_size, seq_len, len(disc_cols)), 0, device=DEVICE, dtype=torch.long)
     mask_tokens = torch.tensor(vocab_sizes, device=DEVICE)
 
     # Initialize discrete with mask tokens
     for i in range(len(disc_cols)):
         x_disc[:, :, i] = mask_tokens[i]
 
-    for t in reversed(range(TIMESTEPS)):
-        t_batch = torch.full((BATCH_SIZE,), t, device=DEVICE, dtype=torch.long)
-        eps_pred, logits = model(x_cont, x_disc, t_batch)
+    cond = None
+    if use_condition:
+        if cond_vocab_size <= 0:
+            raise SystemExit("use_condition enabled but no files matched data_glob")
+        cond = torch.randint(0, cond_vocab_size, (batch_size,), device=DEVICE, dtype=torch.long)
+
+    for t in reversed(range(timesteps)):
+        t_batch = torch.full((batch_size,), t, device=DEVICE, dtype=torch.long)
+        eps_pred, logits = model(x_cont, x_disc, t_batch, cond)
 
         # Continuous reverse step (DDPM): x_{t-1} mean
         a_t = alphas[t]
@@ -74,6 +109,8 @@ def main():
             x_cont = mean + torch.sqrt(betas[t]) * noise
         else:
             x_cont = mean
+        if clip_k > 0:
+            x_cont = torch.clamp(x_cont, -clip_k, clip_k)
 
         # Discrete: fill masked positions by sampling logits
         for i, logit in enumerate(logits):
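
Note on the continuous update above: the `mean` used in the `# Continuous reverse step (DDPM)` block is the standard DDPM posterior mean computed from `eps_pred`, `betas`, `alphas`, and `alphas_cumprod`. A minimal, self-contained sketch of that step with the new `clip_k` clamp from this commit folded in; the helper name and signature are illustrative, not part of the repository:

import torch

def ddpm_reverse_step(x_t, eps_pred, t, betas, alphas, alphas_cumprod, clip_k=5.0):
    """One DDPM reverse step for the continuous channels: form the x_{t-1}
    mean from the predicted noise, add sqrt(beta_t)-scaled noise for t > 0,
    then clamp to +/- clip_k as the sampling script does."""
    a_t = alphas[t]
    ab_t = alphas_cumprod[t]
    # Posterior mean: (x_t - beta_t / sqrt(1 - abar_t) * eps) / sqrt(alpha_t)
    mean = (x_t - betas[t] / torch.sqrt(1.0 - ab_t) * eps_pred) / torch.sqrt(a_t)
    if t > 0:
        x_prev = mean + torch.sqrt(betas[t]) * torch.randn_like(x_t)
    else:
        x_prev = mean
    if clip_k > 0:
        x_prev = torch.clamp(x_prev, -clip_k, clip_k)
    return x_prev

Setting clip_k to 0 (or a negative value) in config_used.json skips the clamp, matching the `if clip_k > 0` guard in the diff.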
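The diff is truncated right after `for i, logit in enumerate(logits):`, so the body of the discrete update is not visible here. Purely as a reading aid, here is one plausible shape of that loop, sampling per-feature logits and writing the samples into positions that still hold the mask token; the helper and the assumption that each logit tensor has a trailing [MASK] slot at index `vocab_sizes[i]` are illustrative, not taken from this commit:

import torch

def fill_masked_positions(x_disc, logits, vocab_sizes):
    """Illustrative only: sample each discrete feature from its logits and
    write the samples into positions that still hold the mask token
    (index vocab_sizes[i] for feature i, as mask_tokens is built above)."""
    for i, logit in enumerate(logits):
        # logit is assumed to be (batch, seq_len, vocab_sizes[i] + 1); drop the mask slot
        probs = torch.softmax(logit[..., : vocab_sizes[i]], dim=-1)
        sampled = torch.distributions.Categorical(probs=probs).sample()
        still_masked = x_disc[:, :, i] == vocab_sizes[i]
        x_disc[:, :, i] = torch.where(still_masked, sampled, x_disc[:, :, i])
    return x_disc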