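80/10/10 Data Split Explained | AI-generated and translated
Question: Can you explain the 80/10/10 train/validation/test split used in nanoGPT and nanochat, and why this separation matters?
Answer:
Yes. An 8/1/1 split (80% training, 10% validation, 10% test) is a common convention for small projects in the nanoGPT/nanochat mold. Note that nanoGPT's own data-prep scripts write only `train.bin` and `val.bin` (for example, a 90/10 split for Shakespeare), so a test split is something you add yourself. Here is the reasoning and one way to implement it: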
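Why three datasets?
Training set (80%): learns the parameters. Gradient updates happen only here.
Validation set (10%): tunes hyperparameters and detects overfitting. During training, evaluate on the validation set every N steps and watch the gap between training and validation loss; a widening gap is the overfitting signal.
Test set (10%): final evaluation. Once all hyperparameter tuning is locked in, evaluate on the test set exactly once. This gives an unbiased estimate of generalization performance.
The key point: the validation and test sets must be completely separate from the training set. If you use the test set during development (for example, to choose hyperparameters), information leaks into your choices and the final result becomes meaningless.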
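The "tune on validation, test once" rule can be made mechanical. Below is a minimal sketch, assuming hypothetical names (`OneShotTest`, `tune_then_test`) and toy loss tables with illustrative numbers (not measurements) standing in for real evaluations: model selection only ever looks at validation loss, and the test evaluator refuses a second call.

```python
class OneShotTest:
    """Wraps a test-set evaluation so it can run exactly once (hypothetical helper)."""

    def __init__(self, evaluate):
        self._evaluate = evaluate
        self._used = False

    def __call__(self, config):
        if self._used:
            # A second look at the test set would let it influence decisions
            raise RuntimeError("test set already used; further evaluations leak information")
        self._used = True
        return self._evaluate(config)


def tune_then_test(configs, val_loss, test_once):
    """Select the config with the lowest validation loss, then test it once."""
    best = min(configs, key=val_loss)  # selection sees validation loss only
    return best, test_once(best)       # single, final test evaluation


# Toy example: candidate learning rates with illustrative validation/test losses
val_table = {3e-4: 1.52, 1e-3: 1.47, 3e-3: 1.61}
test_table = {3e-4: 1.55, 1e-3: 1.49, 3e-3: 1.66}
test_once = OneShotTest(test_table.__getitem__)
best_lr, test_loss = tune_then_test(list(val_table), val_table.__getitem__, test_once)
print(best_lr, test_loss)  # 0.001 1.49
```

Calling `test_once` again after this raises `RuntimeError`, which is the point: once the test number exists, it should not feed back into any decision.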
A nanoGPT-style implementation
```python
import numpy as np
import torch

# Core split, nanoGPT style: one token stream cut into three contiguous slices
data = np.memmap('data.bin', dtype=np.uint16, mode='r')
n = len(data)
train_data = data[:int(0.8 * n)]            # 80% for training
val_data = data[int(0.8 * n):int(0.9 * n)]  # 10% for validation
test_data = data[int(0.9 * n):]             # 10% for testing

splits = {'train': train_data, 'val': val_data, 'test': test_data}

def get_batch(split, batch_size=32, context_length=256):
    """Sample a random batch of (input, target) windows from one split.

    context_length=256 is an assumed default; use your model's block size.
    """
    source = splits[split]  # each split samples only from its own slice
    ix = torch.randint(len(source) - context_length, (batch_size,))
    x = torch.stack([torch.from_numpy(source[i:i + context_length].astype(np.int64)) for i in ix])
    # Targets are the same windows shifted one token to the right
    y = torch.stack([torch.from_numpy(source[i + 1:i + 1 + context_length].astype(np.int64)) for i in ix])
    return x, y
```
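Splitting one contiguous token stream, rather than shuffling individual tokens, keeps every validation and test window made of text the model never trains on. As a minimal sketch, the hypothetical helper `split_bounds` below (not part of nanoGPT) computes the same boundaries with the same `int(fraction * n)` truncation as the slicing above, taking cumulative boundaries `train_split` and `val_split`, and returns half-open ranges that tile `[0, n)` with no gaps or overlaps.

```python
def split_bounds(n, train_split=0.8, val_split=0.9):
    """Return half-open (start, end) ranges for train/val/test.

    train_split and val_split are cumulative boundaries; the int(fraction * n)
    truncation matches nanoGPT-style slicing.
    """
    if not 0.0 <= train_split <= val_split <= 1.0:
        raise ValueError("need 0 <= train_split <= val_split <= 1")
    train_end = int(train_split * n)
    val_end = int(val_split * n)
    # Each range starts exactly where the previous one ends: no gaps, no overlap
    return {
        'train': (0, train_end),
        'val': (train_end, val_end),
        'test': (val_end, n),
    }


print(split_bounds(1000))
# {'train': (0, 800), 'val': (800, 900), 'test': (900, 1000)}
```

Because `get_batch` draws windows from within a single slice, no sampled window straddles two splits.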
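Training loop pattern
```python
for iter in range(max_iters):
    # Sample from the training set, compute the loss, backpropagate
    xb, yb = get_batch('train')
    logits, loss = model(xb, yb)  # nanoGPT's model takes (inputs, targets)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    # Every eval_interval steps, check the validation loss (no gradients)
    if iter % eval_interval == 0:
        with torch.no_grad():
            val_loss = estimate_loss('val')  # assumed helper: forward passes only, on the validation set
        # loss.item() is the current training batch's loss, a noisy single-batch estimate
        print(f"iter {iter}: train_loss {loss.item():.4f}, val_loss {val_loss:.4f}")
```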
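The loop calls an `estimate_loss('val')` helper that the snippet does not define. Below is a minimal, framework-agnostic sketch of such a helper, assuming the batch sampler and a loss function are passed in explicitly, so the call would read `estimate_loss('val', get_batch, compute_loss)`; `compute_loss` and the `eval_iters=20` default are hypothetical. Averaging over several batches matters because a single batch's loss is noisy. With PyTorch you would call it inside `torch.no_grad()` with the model switched to eval mode.

```python
from itertools import cycle


def estimate_loss(split, get_batch, compute_loss, eval_iters=20):
    """Average the loss over eval_iters random batches drawn from one split.

    get_batch(split) -> (x, y) and compute_loss(x, y) -> number are assumed
    interfaces; with PyTorch, compute_loss would run the model forward and
    return loss.item().
    """
    total = 0.0
    for _ in range(eval_iters):
        x, y = get_batch(split)
        total += float(compute_loss(x, y))
    return total / eval_iters


# Toy usage: a fake sampler and a fake loss that cycles through fixed values
toy_losses = cycle([1.5, 2.0, 2.5])
val_loss = estimate_loss('val', lambda split: (None, None), lambda x, y: next(toy_losses), eval_iters=3)
print(val_loss)  # 2.0
```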
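The validation loss tells you whether you are overfitting:
- If val_loss >> train_loss: the model has memorized the training data and is not generalizing.
- If both fall together: the model is generalizing well; keep training.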
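The two rules above can be turned into a simple check over the loss history. This is a hypothetical heuristic, not something nanoGPT ships: it flags overfitting when the validation loss has not improved for `patience` consecutive evaluations while the training loss kept falling. The `patience=3` default is an assumption you would tune.

```python
def overfitting_signal(train_losses, val_losses, patience=3):
    """Return True if val loss stalled for `patience` evals while train loss fell.

    Both lists hold one entry per evaluation, oldest first, and have equal length.
    """
    if len(val_losses) <= patience:
        return False  # not enough history yet
    best_before = min(val_losses[:-patience])
    val_stalled = min(val_losses[-patience:]) >= best_before
    train_falling = train_losses[-1] < train_losses[-patience - 1]
    return val_stalled and train_falling


# Healthy run: both curves fall together
print(overfitting_signal([3.0, 2.6, 2.3, 2.1, 2.0], [3.1, 2.7, 2.4, 2.2, 2.1]))  # False
# Overfitting: train keeps dropping while val bottoms out and rises
print(overfitting_signal([3.0, 2.4, 1.9, 1.5, 1.2], [3.1, 2.6, 2.7, 2.8, 2.9]))  # True
```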
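Special considerations for language models
At larger scale, the thinking about splits changes a little:
- For small models (nanoGPT scale), an 8/1/1 split is fine, with validation and test sets of equal size.
- For large models, the validation set matters more than the test set:
  - It is evaluated 10-100x more often during training, and those evaluations drive training decisions such as early stopping or learning-rate changes (it never contributes gradients).
  - The test set is evaluated only once, at the end, and is never used for tuning.
- Some teams use 9/0.5/0.5 (90/5/5) or an even larger training share, because more training data generally helps.
- For production models, you might also hold out a separate test distribution (a different domain or time period) to catch distribution shift that the validation set would not reveal.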
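A held-out time period, as in the last bullet, can be carved out by timestamp instead of by position. A minimal sketch, assuming each document carries an ISO-8601 date string and using a hypothetical cutoff date: everything on or after the cutoff becomes an out-of-distribution test set, and the usual positional split would then be applied only to the older documents.

```python
def time_holdout(docs, cutoff):
    """Split (timestamp, text) pairs into in-period and held-out-period lists.

    ISO-8601 date strings (YYYY-MM-DD) compare correctly as plain strings.
    """
    in_period = [d for d in docs if d[0] < cutoff]
    held_out = [d for d in docs if d[0] >= cutoff]
    return in_period, held_out


docs = [
    ("2022-03-01", "older article"),
    ("2023-07-15", "mid article"),
    ("2024-02-10", "newer article"),
]
in_period, held_out = time_holdout(docs, cutoff="2024-01-01")
print(len(in_period), len(held_out))  # 2 1
```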
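Practical considerations for large datasets
If you are training on billions of tokens (as you are exploring on the MI300X), a 10% validation or test set can run to gigabytes. You can split more aggressively:
```python
# For multi-billion-token datasets, 5% validation / 5% test is acceptable
train_split = 0.90  # cumulative boundary: train = [0, 0.90)
val_split = 0.95    # cumulative boundary: val = [0.90, 0.95), test = [0.95, 1.0)
```
The absolute size of the validation and test sets matters more than the percentage: you need enough tokens for a stable loss estimate, but beyond roughly 1 GB of validation data the returns diminish.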
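To see how a percentage split reaches gigabyte scale, here is the arithmetic as a small sketch. It assumes tokens are stored as uint16 (2 bytes each, as with `np.uint16` earlier) and uses a hypothetical 10-billion-token corpus with the cumulative 90/5/5 boundaries; `split_sizes` is an illustrative helper, not a library function.

```python
def split_sizes(total_tokens, train_split=0.90, val_split=0.95, bytes_per_token=2):
    """Token counts and on-disk bytes per split for cumulative boundaries.

    bytes_per_token=2 assumes uint16 token ids.
    """
    train_end = int(train_split * total_tokens)
    val_end = int(val_split * total_tokens)
    counts = {
        'train': train_end,
        'val': val_end - train_end,
        'test': total_tokens - val_end,
    }
    return {name: (tokens, tokens * bytes_per_token) for name, tokens in counts.items()}


for name, (tokens, nbytes) in split_sizes(10_000_000_000).items():
    print(f"{name}: {tokens:,} tokens, {nbytes / 1e9:.1f} GB")
```

For 10 billion tokens this prints 9,000,000,000 training tokens (18.0 GB) and 500,000,000 tokens (1.0 GB) each for validation and test, so even a 5% slice already sits around the 1 GB mark.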
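Summary: keep the validation and test sets strictly separate. Use the validation loss during training to detect overfitting and to tune hyperparameters such as learning rate and dropout. Use the test set only once, after everything is locked in. If you want your results to be credible, this is non-negotiable.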
