[{"data":1,"prerenderedAt":12666},["ShallowReactive",2],{"category-data-deep-learning":3},[4],{"id":5,"title":6,"body":7,"description":12654,"extension":12655,"meta":12656,"navigation":2724,"ogImage":12658,"path":12662,"seo":12663,"stem":12664,"__hash__":12665},"content/blogs/1. transformer复习笔记.md","transformer 复习笔记",{"type":8,"value":9,"toc":12638},"minimark",[10,14,18,35,38,45,48,52,55,315,318,442,758,1190,1196,1467,2386,2389,2618,2621,2861,2868,2871,2874,2879,3211,4558,4575,4801,5260,5263,5358,5361,5948,6309,6665,7022,7295,7298,7301,7634,7791,7794,7797,7989,8128,8131,8643,9042,9046,9049,9661,9665,9668,10668,10671,10674,10965,10968,11125,11136,11730,11733,12440,12443,12621,12624,12629,12634],[11,12,13],"h2",{"id":13},"背景",[15,16,17],"p",{},"距离首次学习 transformer 已经过去一年，最近在回忆 transformer 的相关内容，发现相关内容忘的差不多了，于是决定复习一下并记录下来。",[15,19,20,21,28,29,34],{},"本文主要记录对transformer的理解和思考，代码实现完全参考 ",[22,23,27],"a",{"href":24,"rel":25},"https://zh.d2l.ai/chapter_attention-mechanisms/transformer.html",[26],"nofollow","动手学深度学习"," 部分，部分理论的解释参考 ",[22,30,33],{"href":31,"rel":32},"https://nlp.seas.harvard.edu/annotated-transformer/",[26],"The Annotated Transformer"," 内容。",[11,36,37],{"id":37},"整体框架",[15,39,40],{},[41,42],"img",{"alt":43,"src":44},"transformer架构图","transformer/transformer.png",[11,46,47],{"id":47},"注意力机制",[49,50,51],"h3",{"id":51},"注意力函数",[15,53,54],{},"有两种常见的注意力函数，分别为加性注意力和缩放点积注意力。",[15,56,57,58,314],{},"点积注意力函数与缩放点积注意力函数基本相同，但缩放点积注意力函数有一个额外的缩放因子 ",[59,60,63,102],"span",{"className":61},[62],"katex",[59,64,67],{"className":65},[66],"katex-mathml",[68,69,71],"math",{"xmlns":70},"http://www.w3.org/1998/Math/MathML",[72,73,74,97],"semantics",{},[75,76,77],"mrow",{},[78,79,80,84],"mfrac",{},[81,82,83],"mn",{},"1",[85,86,87],"msqrt",{},[88,89,90,94],"msub",{},[91,92,93],"mi",{},"d",[91,95,96],{},"k",[98,99,101],"annotation",{"encoding":100},"application/x-tex","\\frac{1}{\\sqrt{d_k}}",[59,103,107],{"className":104,"ariaHidden":106},[105],"katex-html","true",[59,108,111,116],{"className":109},[110],"base",[59,112],{"className":113,"style":115},[114],"strut","height:1.3831em;vertical-align:-0.538em;",[59,117,120,125,310],{"className":118},[119],"mord",[59,121],{"className":122},[123,124],"mopen","nulldelimiter",[59,126,128],{"className":127},[78],[59,129,133,301],{"className":130},[131,132],"vlist-t","vlist-t2",[59,134,137,298],{"className":135},[136],"vlist-r",[59,138,142,272,283],{"className":139,"style":141},[140],"vlist","height:0.8451em;",[59,143,145,150],{"style":144},"top:-2.5864em;",[59,146],{"className":147,"style":149},[148],"pstrut","height:3em;",[59,151,157],{"className":152},[153,154,155,156],"sizing","reset-size6","size3","mtight",[59,158,160],{"className":159},[119,156],[59,161,164],{"className":162},[119,163,156],"sqrt",[59,165,167,263],{"className":166},[131,132],[59,168,170,260],{"className":169},[136],[59,171,174,237],{"className":172,"style":173},[140],"height:0.8622em;",[59,175,179,182],{"className":176,"style":178},[177],"svg-align","top:-3em;",[59,180],{"className":181,"style":149},[148],[59,183,186],{"className":184,"style":185},[119,156],"padding-left:0.833em;",[59,187,189,193],{"className":188},[119,156],[59,190,93],{"className":191},[119,192,156],"mathnormal",[59,194,197],{"className":195},[196],"msupsub",[59,198,200,228],{"className":199},[131,132],[59,201,203,223],{"className":202},[136],[59,204,207],{"className":205,"style":206},[140],"height:0.3448em;",[59,208,210,214],{"style":209},"top:-2.3488em;margin-left:0em;margin-right:0.0714em;",[59,211],{"className":212,"style":213},[148],"height:2.5em;",[59,215,219],{"className":216},[153,217,218,156],"reset-size3","size1",[59,220,96],{"className":221,"style":222},[119,192,156],"margin-right:0.03148em;",[59,224,227],{"className":225},[226],"vlist-s","​",[59,229,231],{"className":230},[136],[59,232,235],{"className":233,"style":234},[140],"height:0.1512em;",[59,236],{},[59,238,240,243],{"style":239},"top:-2.8222em;",[59,241],{"className":242,"style":149},[148],[59,244,248],{"className":245,"style":247},[246,156],"hide-tail","min-width:0.853em;height:1.08em;",[249,250,256],"svg",{"xmlns":251,"width":252,"height":253,"viewBox":254,"preserveAspectRatio":255},"http://www.w3.org/2000/svg","400em","1.08em","0 0 400000 1080","xMinYMin slice",[257,258],"path",{"d":259},"M95,702\nc-2.7,0,-7.17,-2.7,-13.5,-8c-5.8,-5.3,-9.5,-10,-9.5,-14\nc0,-2,0.3,-3.3,1,-4c1.3,-2.7,23.83,-20.7,67.5,-54\nc44.2,-33.3,65.8,-50.3,66.5,-51c1.3,-1.3,3,-2,5,-2c4.7,0,8.7,3.3,12,10\ns173,378,173,378c0.7,0,35.3,-71,104,-213c68.7,-142,137.5,-285,206.5,-429\nc69,-144,104.5,-217.7,106.5,-221\nl0 -0\nc5.3,-9.3,12,-14,20,-14\nH400000v40H845.2724\ns-225.272,467,-225.272,467s-235,486,-235,486c-2.7,4.7,-9,7,-19,7\nc-6,0,-10,-1,-12,-3s-194,-422,-194,-422s-65,47,-65,47z\nM834 80h400000v40h-400000z",[59,261,227],{"className":262},[226],[59,264,266],{"className":265},[136],[59,267,270],{"className":268,"style":269},[140],"height:0.1778em;",[59,271],{},[59,273,275,278],{"style":274},"top:-3.23em;",[59,276],{"className":277,"style":149},[148],[59,279],{"className":280,"style":282},[281],"frac-line","border-bottom-width:0.04em;",[59,284,286,289],{"style":285},"top:-3.394em;",[59,287],{"className":288,"style":149},[148],[59,290,292],{"className":291},[153,154,155,156],[59,293,295],{"className":294},[119,156],[59,296,83],{"className":297},[119,156],[59,299,227],{"className":300},[226],[59,302,304],{"className":303},[136],[59,305,308],{"className":306,"style":307},[140],"height:0.538em;",[59,309],{},[59,311],{"className":312},[313,124],"mclose","。",[15,316,317],{},"加性注意力通过一个具有单个隐藏层的前馈网络计算兼容性函数。虽然两者在理论复杂度上相似，但在实际应用中，点积注意力更快且更节省空间，因为它可以使用高度优化的矩阵乘法代码实现。",[15,319,320,321,351,352,381,382,411,412,314],{},"在实践中，我们通常从小批量的角度来考虑提高效率，例如基于 ",[59,322,324,338],{"className":323},[62],[59,325,327],{"className":326},[66],[68,328,329],{"xmlns":70},[72,330,331,336],{},[75,332,333],{},[91,334,335],{},"n",[98,337,335],{"encoding":100},[59,339,341],{"className":340,"ariaHidden":106},[105],[59,342,344,348],{"className":343},[110],[59,345],{"className":346,"style":347},[114],"height:0.4306em;",[59,349,335],{"className":350},[119,192]," 个查询和 ",[59,353,355,369],{"className":354},[62],[59,356,358],{"className":357},[66],[68,359,360],{"xmlns":70},[72,361,362,367],{},[75,363,364],{},[91,365,366],{},"m",[98,368,366],{"encoding":100},[59,370,372],{"className":371,"ariaHidden":106},[105],[59,373,375,378],{"className":374},[110],[59,376],{"className":377,"style":347},[114],[59,379,366],{"className":380},[119,192]," 个键-值对计算注意力，其中查询和键的长度为 ",[59,383,385,398],{"className":384},[62],[59,386,388],{"className":387},[66],[68,389,390],{"xmlns":70},[72,391,392,396],{},[75,393,394],{},[91,395,93],{},[98,397,93],{"encoding":100},[59,399,401],{"className":400,"ariaHidden":106},[105],[59,402,404,408],{"className":403},[110],[59,405],{"className":406,"style":407},[114],"height:0.6944em;",[59,409,93],{"className":410},[119,192],"，值的长度为 ",[59,413,415,429],{"className":414},[62],[59,416,418],{"className":417},[66],[68,419,420],{"xmlns":70},[72,421,422,427],{},[75,423,424],{},[91,425,426],{},"v",[98,428,426],{"encoding":100},[59,430,432],{"className":431,"ariaHidden":106},[105],[59,433,435,438],{"className":434},[110],[59,436],{"className":437,"style":347},[114],[59,439,426],{"className":440,"style":441},[119,192],"margin-right:0.03588em;",[15,443,444,445,557,558,657,658,757],{},"查询 ",[59,446,448,483],{"className":447},[62],[59,449,451],{"className":450},[66],[68,452,453],{"xmlns":70},[72,454,455,480],{},[75,456,457,460,464],{},[91,458,459],{},"Q",[461,462,463],"mo",{},"∈",[465,466,467,471],"msup",{},[91,468,470],{"mathvariant":469},"double-struck","R",[75,472,473,475,478],{},[91,474,335],{},[461,476,477],{},"×",[91,479,93],{},[98,481,482],{"encoding":100},"Q \\in \\mathbb{R}^{n \\times d}",[59,484,486,508],{"className":485,"ariaHidden":106},[105],[59,487,489,493,496,501,505],{"className":488},[110],[59,490],{"className":491,"style":492},[114],"height:0.8778em;vertical-align:-0.1944em;",[59,494,459],{"className":495},[119,192],[59,497],{"className":498,"style":500},[499],"mspace","margin-right:0.2778em;",[59,502,463],{"className":503},[504],"mrel",[59,506],{"className":507,"style":500},[499],[59,509,511,515],{"className":510},[110],[59,512],{"className":513,"style":514},[114],"height:0.8491em;",[59,516,518,522],{"className":517},[119],[59,519,470],{"className":520},[119,521],"mathbb",[59,523,525],{"className":524},[196],[59,526,528],{"className":527},[131],[59,529,531],{"className":530},[136],[59,532,534],{"className":533,"style":514},[140],[59,535,537,541],{"style":536},"top:-3.063em;margin-right:0.05em;",[59,538],{"className":539,"style":540},[148],"height:2.7em;",[59,542,544],{"className":543},[153,154,155,156],[59,545,547,550,554],{"className":546},[119,156],[59,548,335],{"className":549},[119,192,156],[59,551,477],{"className":552},[553,156],"mbin",[59,555,93],{"className":556},[119,192,156],"、键 ",[59,559,561,590],{"className":560},[62],[59,562,564],{"className":563},[66],[68,565,566],{"xmlns":70},[72,567,568,587],{},[75,569,570,573,575],{},[91,571,572],{},"K",[461,574,463],{},[465,576,577,579],{},[91,578,470],{"mathvariant":469},[75,580,581,583,585],{},[91,582,366],{},[461,584,477],{},[91,586,93],{},[98,588,589],{"encoding":100},"K \\in \\mathbb{R}^{m \\times d}",[59,591,593,613],{"className":592,"ariaHidden":106},[105],[59,594,596,600,604,607,610],{"className":595},[110],[59,597],{"className":598,"style":599},[114],"height:0.7224em;vertical-align:-0.0391em;",[59,601,572],{"className":602,"style":603},[119,192],"margin-right:0.07153em;",[59,605],{"className":606,"style":500},[499],[59,608,463],{"className":609},[504],[59,611],{"className":612,"style":500},[499],[59,614,616,619],{"className":615},[110],[59,617],{"className":618,"style":514},[114],[59,620,622,625],{"className":621},[119],[59,623,470],{"className":624},[119,521],[59,626,628],{"className":627},[196],[59,629,631],{"className":630},[131],[59,632,634],{"className":633},[136],[59,635,637],{"className":636,"style":514},[140],[59,638,639,642],{"style":536},[59,640],{"className":641,"style":540},[148],[59,643,645],{"className":644},[153,154,155,156],[59,646,648,651,654],{"className":647},[119,156],[59,649,366],{"className":650},[119,192,156],[59,652,477],{"className":653},[553,156],[59,655,93],{"className":656},[119,192,156]," 和值 ",[59,659,661,690],{"className":660},[62],[59,662,664],{"className":663},[66],[68,665,666],{"xmlns":70},[72,667,668,687],{},[75,669,670,673,675],{},[91,671,672],{},"V",[461,674,463],{},[465,676,677,679],{},[91,678,470],{"mathvariant":469},[75,680,681,683,685],{},[91,682,366],{},[461,684,477],{},[91,686,426],{},[98,688,689],{"encoding":100},"V \\in \\mathbb{R}^{m \\times v}",[59,691,693,712],{"className":692,"ariaHidden":106},[105],[59,694,696,699,703,706,709],{"className":695},[110],[59,697],{"className":698,"style":599},[114],[59,700,672],{"className":701,"style":702},[119,192],"margin-right:0.22222em;",[59,704],{"className":705,"style":500},[499],[59,707,463],{"className":708},[504],[59,710],{"className":711,"style":500},[499],[59,713,715,719],{"className":714},[110],[59,716],{"className":717,"style":718},[114],"height:0.7713em;",[59,720,722,725],{"className":721},[119],[59,723,470],{"className":724},[119,521],[59,726,728],{"className":727},[196],[59,729,731],{"className":730},[131],[59,732,734],{"className":733},[136],[59,735,737],{"className":736,"style":718},[140],[59,738,739,742],{"style":536},[59,740],{"className":741,"style":540},[148],[59,743,745],{"className":744},[153,154,155,156],[59,746,748,751,754],{"className":747},[119,156],[59,749,366],{"className":750},[119,192,156],[59,752,477],{"className":753},[553,156],[59,755,426],{"className":756,"style":441},[119,192,156]," 的缩放点积注意力是：",[15,759,760],{},[59,761,763,868],{"className":762},[62],[59,764,766],{"className":765},[66],[68,767,768],{"xmlns":70},[72,769,770,865],{},[75,771,772,775,778,780,783,785,787,790,793,795,799,801,804,806,808,810,813,816,820,823,846,848,850,862],{},[91,773,774],{},"A",[91,776,777],{},"t",[91,779,777],{},[91,781,782],{},"e",[91,784,335],{},[91,786,777],{},[91,788,789],{},"i",[91,791,792],{},"o",[91,794,335],{},[461,796,798],{"stretchy":797},"false","(",[91,800,459],{},[461,802,803],{"separator":106},",",[91,805,572],{},[461,807,803],{"separator":106},[91,809,672],{},[461,811,812],{"stretchy":797},")",[461,814,815],{},"=",[91,817,819],{"mathvariant":818},"normal","softmax",[461,821,822],{},"⁡",[75,824,825,827,844],{},[461,826,798],{"fence":106},[78,828,829,840],{},[75,830,831,833],{},[91,832,459],{},[465,834,835,837],{},[91,836,572],{},[91,838,839],{"mathvariant":818},"⊤",[85,841,842],{},[91,843,93],{},[461,845,812],{"fence":106},[91,847,672],{},[461,849,463],{},[465,851,852,854],{},[91,853,470],{"mathvariant":469},[75,855,856,858,860],{},[91,857,335],{},[461,859,477],{},[91,861,426],{},[91,863,864],{"mathvariant":818},".",[98,866,867],{"encoding":100}," Attention(Q,K,V) = \\operatorname{softmax}\\left(\\frac{QK^\\top}{\\sqrt{d}}\\right)V \\in \\mathbb{R}^{n \\times v}.",[59,869,871,941,1143],{"className":870,"ariaHidden":106},[105],[59,872,874,878,881,885,888,891,894,897,900,903,906,909,913,917,920,923,926,929,932,935,938],{"className":873},[110],[59,875],{"className":876,"style":877},[114],"height:1em;vertical-align:-0.25em;",[59,879,774],{"className":880},[119,192],[59,882,884],{"className":883},[119,192],"tt",[59,886,782],{"className":887},[119,192],[59,889,335],{"className":890},[119,192],[59,892,777],{"className":893},[119,192],[59,895,789],{"className":896},[119,192],[59,898,792],{"className":899},[119,192],[59,901,335],{"className":902},[119,192],[59,904,798],{"className":905},[123],[59,907,459],{"className":908},[119,192],[59,910,803],{"className":911},[912],"mpunct",[59,914],{"className":915,"style":916},[499],"margin-right:0.1667em;",[59,918,572],{"className":919,"style":603},[119,192],[59,921,803],{"className":922},[912],[59,924],{"className":925,"style":916},[499],[59,927,672],{"className":928,"style":702},[119,192],[59,930,812],{"className":931},[313],[59,933],{"className":934,"style":500},[499],[59,936,815],{"className":937},[504],[59,939],{"className":940,"style":500},[499],[59,942,944,948,956,959,1128,1131,1134,1137,1140],{"className":943},[110],[59,945],{"className":946,"style":947},[114],"height:1.8em;vertical-align:-0.65em;",[59,949,952],{"className":950},[951],"mop",[59,953,819],{"className":954},[119,955],"mathrm",[59,957],{"className":958,"style":916},[499],[59,960,963,973,1122],{"className":961},[962],"minner",[59,964,968],{"className":965,"style":967},[123,966],"delimcenter","top:0em;",[59,969,798],{"className":970},[971,972],"delimsizing","size2",[59,974,976,979,1119],{"className":975},[119],[59,977],{"className":978},[123,124],[59,980,982],{"className":981},[78],[59,983,985,1111],{"className":984},[131,132],[59,986,988,1108],{"className":987},[136],[59,989,992,1054,1062],{"className":990,"style":991},[140],"height:1.095em;",[59,993,995,998],{"style":994},"top:-2.5335em;",[59,996],{"className":997,"style":149},[148],[59,999,1001],{"className":1000},[153,154,155,156],[59,1002,1004],{"className":1003},[119,156],[59,1005,1007],{"className":1006},[119,163,156],[59,1008,1010,1045],{"className":1009},[131,132],[59,1011,1013,1042],{"className":1012},[136],[59,1014,1017,1029],{"className":1015,"style":1016},[140],"height:0.9378em;",[59,1018,1020,1023],{"className":1019,"style":178},[177],[59,1021],{"className":1022,"style":149},[148],[59,1024,1026],{"className":1025,"style":185},[119,156],[59,1027,93],{"className":1028},[119,192,156],[59,1030,1032,1035],{"style":1031},"top:-2.8978em;",[59,1033],{"className":1034,"style":149},[148],[59,1036,1038],{"className":1037,"style":247},[246,156],[249,1039,1040],{"xmlns":251,"width":252,"height":253,"viewBox":254,"preserveAspectRatio":255},[257,1041],{"d":259},[59,1043,227],{"className":1044},[226],[59,1046,1048],{"className":1047},[136],[59,1049,1052],{"className":1050,"style":1051},[140],"height:0.1022em;",[59,1053],{},[59,1055,1056,1059],{"style":274},[59,1057],{"className":1058,"style":149},[148],[59,1060],{"className":1061,"style":282},[281],[59,1063,1065,1068],{"style":1064},"top:-3.4461em;",[59,1066],{"className":1067,"style":149},[148],[59,1069,1071],{"className":1070},[153,154,155,156],[59,1072,1074,1077],{"className":1073},[119,156],[59,1075,459],{"className":1076},[119,192,156],[59,1078,1080,1083],{"className":1079},[119,156],[59,1081,572],{"className":1082,"style":603},[119,192,156],[59,1084,1086],{"className":1085},[196],[59,1087,1089],{"className":1088},[131],[59,1090,1092],{"className":1091},[136],[59,1093,1096],{"className":1094,"style":1095},[140],"height:0.927em;",[59,1097,1099,1102],{"style":1098},"top:-2.931em;margin-right:0.0714em;",[59,1100],{"className":1101,"style":213},[148],[59,1103,1105],{"className":1104},[153,217,218,156],[59,1106,839],{"className":1107},[119,156],[59,1109,227],{"className":1110},[226],[59,1112,1114],{"className":1113},[136],[59,1115,1117],{"className":1116,"style":307},[140],[59,1118],{},[59,1120],{"className":1121},[313,124],[59,1123,1125],{"className":1124,"style":967},[313,966],[59,1126,812],{"className":1127},[971,972],[59,1129],{"className":1130,"style":916},[499],[59,1132,672],{"className":1133,"style":702},[119,192],[59,1135],{"className":1136,"style":500},[499],[59,1138,463],{"className":1139},[504],[59,1141],{"className":1142,"style":500},[499],[59,1144,1146,1149,1187],{"className":1145},[110],[59,1147],{"className":1148,"style":718},[114],[59,1150,1152,1155],{"className":1151},[119],[59,1153,470],{"className":1154},[119,521],[59,1156,1158],{"className":1157},[196],[59,1159,1161],{"className":1160},[131],[59,1162,1164],{"className":1163},[136],[59,1165,1167],{"className":1166,"style":718},[140],[59,1168,1169,1172],{"style":536},[59,1170],{"className":1171,"style":540},[148],[59,1173,1175],{"className":1174},[153,154,155,156],[59,1176,1178,1181,1184],{"className":1177},[119,156],[59,1179,335],{"className":1180},[119,192,156],[59,1182,477],{"className":1183},[553,156],[59,1185,426],{"className":1186,"style":441},[119,192,156],[59,1188,864],{"className":1189},[119],[15,1191,1192],{},[41,1193],{"alt":1194,"src":1195},"注意力","transformer/attention.png",[15,1197,1198,1199,1273,1274,1345,1346,1466],{},"缩放点积注意力的输入由维度为 ",[59,1200,1202,1220],{"className":1201},[62],[59,1203,1205],{"className":1204},[66],[68,1206,1207],{"xmlns":70},[72,1208,1209,1217],{},[75,1210,1211],{},[88,1212,1213,1215],{},[91,1214,93],{},[91,1216,96],{},[98,1218,1219],{"encoding":100},"d_k",[59,1221,1223],{"className":1222,"ariaHidden":106},[105],[59,1224,1226,1230],{"className":1225},[110],[59,1227],{"className":1228,"style":1229},[114],"height:0.8444em;vertical-align:-0.15em;",[59,1231,1233,1236],{"className":1232},[119],[59,1234,93],{"className":1235},[119,192],[59,1237,1239],{"className":1238},[196],[59,1240,1242,1264],{"className":1241},[131,132],[59,1243,1245,1261],{"className":1244},[136],[59,1246,1249],{"className":1247,"style":1248},[140],"height:0.3361em;",[59,1250,1252,1255],{"style":1251},"top:-2.55em;margin-left:0em;margin-right:0.05em;",[59,1253],{"className":1254,"style":540},[148],[59,1256,1258],{"className":1257},[153,154,155,156],[59,1259,96],{"className":1260,"style":222},[119,192,156],[59,1262,227],{"className":1263},[226],[59,1265,1267],{"className":1266},[136],[59,1268,1271],{"className":1269,"style":1270},[140],"height:0.15em;",[59,1272],{}," ​的查询和键以及维度为 ",[59,1275,1277,1295],{"className":1276},[62],[59,1278,1280],{"className":1279},[66],[68,1281,1282],{"xmlns":70},[72,1283,1284,1292],{},[75,1285,1286],{},[88,1287,1288,1290],{},[91,1289,93],{},[91,1291,426],{},[98,1293,1294],{"encoding":100},"d_v",[59,1296,1298],{"className":1297,"ariaHidden":106},[105],[59,1299,1301,1304],{"className":1300},[110],[59,1302],{"className":1303,"style":1229},[114],[59,1305,1307,1310],{"className":1306},[119],[59,1308,93],{"className":1309},[119,192],[59,1311,1313],{"className":1312},[196],[59,1314,1316,1337],{"className":1315},[131,132],[59,1317,1319,1334],{"className":1318},[136],[59,1320,1323],{"className":1321,"style":1322},[140],"height:0.1514em;",[59,1324,1325,1328],{"style":1251},[59,1326],{"className":1327,"style":540},[148],[59,1329,1331],{"className":1330},[153,154,155,156],[59,1332,426],{"className":1333,"style":441},[119,192,156],[59,1335,227],{"className":1336},[226],[59,1338,1340],{"className":1339},[136],[59,1341,1343],{"className":1342,"style":1270},[140],[59,1344],{}," ​的值组成。我们计算查询与所有键的点积，将每个点积除以 ",[59,1347,1349,1369],{"className":1348},[62],[59,1350,1352],{"className":1351},[66],[68,1353,1354],{"xmlns":70},[72,1355,1356,1366],{},[75,1357,1358],{},[85,1359,1360],{},[88,1361,1362,1364],{},[91,1363,93],{},[91,1365,96],{},[98,1367,1368],{"encoding":100},"\\sqrt{d_k}",[59,1370,1372],{"className":1371,"ariaHidden":106},[105],[59,1373,1375,1379],{"className":1374},[110],[59,1376],{"className":1377,"style":1378},[114],"height:1.04em;vertical-align:-0.1828em;",[59,1380,1382],{"className":1381},[119,163],[59,1383,1385,1457],{"className":1384},[131,132],[59,1386,1388,1454],{"className":1387},[136],[59,1389,1392,1441],{"className":1390,"style":1391},[140],"height:0.8572em;",[59,1393,1395,1398],{"className":1394,"style":178},[177],[59,1396],{"className":1397,"style":149},[148],[59,1399,1401],{"className":1400,"style":185},[119],[59,1402,1404,1407],{"className":1403},[119],[59,1405,93],{"className":1406},[119,192],[59,1408,1410],{"className":1409},[196],[59,1411,1413,1433],{"className":1412},[131,132],[59,1414,1416,1430],{"className":1415},[136],[59,1417,1419],{"className":1418,"style":1248},[140],[59,1420,1421,1424],{"style":1251},[59,1422],{"className":1423,"style":540},[148],[59,1425,1427],{"className":1426},[153,154,155,156],[59,1428,96],{"className":1429,"style":222},[119,192,156],[59,1431,227],{"className":1432},[226],[59,1434,1436],{"className":1435},[136],[59,1437,1439],{"className":1438,"style":1270},[140],[59,1440],{},[59,1442,1444,1447],{"style":1443},"top:-2.8172em;",[59,1445],{"className":1446,"style":149},[148],[59,1448,1450],{"className":1449,"style":247},[246],[249,1451,1452],{"xmlns":251,"width":252,"height":253,"viewBox":254,"preserveAspectRatio":255},[257,1453],{"d":259},[59,1455,227],{"className":1456},[226],[59,1458,1460],{"className":1459},[136],[59,1461,1464],{"className":1462,"style":1463},[140],"height:0.1828em;",[59,1465],{},"​​，并应用softmax函数以获得值上的权重。",[15,1468,1469,1470,1539,1540,1609,1610,1679,1680,1710,1711,1739,1740,1770,1771,1799,1800,2103,2104,1770,2132,2201,2202,314],{},"对于较小的 ",[59,1471,1473,1490],{"className":1472},[62],[59,1474,1476],{"className":1475},[66],[68,1477,1478],{"xmlns":70},[72,1479,1480,1488],{},[75,1481,1482],{},[88,1483,1484,1486],{},[91,1485,93],{},[91,1487,96],{},[98,1489,1219],{"encoding":100},[59,1491,1493],{"className":1492,"ariaHidden":106},[105],[59,1494,1496,1499],{"className":1495},[110],[59,1497],{"className":1498,"style":1229},[114],[59,1500,1502,1505],{"className":1501},[119],[59,1503,93],{"className":1504},[119,192],[59,1506,1508],{"className":1507},[196],[59,1509,1511,1531],{"className":1510},[131,132],[59,1512,1514,1528],{"className":1513},[136],[59,1515,1517],{"className":1516,"style":1248},[140],[59,1518,1519,1522],{"style":1251},[59,1520],{"className":1521,"style":540},[148],[59,1523,1525],{"className":1524},[153,154,155,156],[59,1526,96],{"className":1527,"style":222},[119,192,156],[59,1529,227],{"className":1530},[226],[59,1532,1534],{"className":1533},[136],[59,1535,1537],{"className":1536,"style":1270},[140],[59,1538],{}," 值来说，加性注意力和点积注意力的表现相似，但加性注意力在 ",[59,1541,1543,1560],{"className":1542},[62],[59,1544,1546],{"className":1545},[66],[68,1547,1548],{"xmlns":70},[72,1549,1550,1558],{},[75,1551,1552],{},[88,1553,1554,1556],{},[91,1555,93],{},[91,1557,96],{},[98,1559,1219],{"encoding":100},[59,1561,1563],{"className":1562,"ariaHidden":106},[105],[59,1564,1566,1569],{"className":1565},[110],[59,1567],{"className":1568,"style":1229},[114],[59,1570,1572,1575],{"className":1571},[119],[59,1573,93],{"className":1574},[119,192],[59,1576,1578],{"className":1577},[196],[59,1579,1581,1601],{"className":1580},[131,132],[59,1582,1584,1598],{"className":1583},[136],[59,1585,1587],{"className":1586,"style":1248},[140],[59,1588,1589,1592],{"style":1251},[59,1590],{"className":1591,"style":540},[148],[59,1593,1595],{"className":1594},[153,154,155,156],[59,1596,96],{"className":1597,"style":222},[119,192,156],[59,1599,227],{"className":1600},[226],[59,1602,1604],{"className":1603},[136],[59,1605,1607],{"className":1606,"style":1270},[140],[59,1608],{}," 值较大时会优于点积注意力。我们怀疑在 ",[59,1611,1613,1630],{"className":1612},[62],[59,1614,1616],{"className":1615},[66],[68,1617,1618],{"xmlns":70},[72,1619,1620,1628],{},[75,1621,1622],{},[88,1623,1624,1626],{},[91,1625,93],{},[91,1627,96],{},[98,1629,1219],{"encoding":100},[59,1631,1633],{"className":1632,"ariaHidden":106},[105],[59,1634,1636,1639],{"className":1635},[110],[59,1637],{"className":1638,"style":1229},[114],[59,1640,1642,1645],{"className":1641},[119],[59,1643,93],{"className":1644},[119,192],[59,1646,1648],{"className":1647},[196],[59,1649,1651,1671],{"className":1650},[131,132],[59,1652,1654,1668],{"className":1653},[136],[59,1655,1657],{"className":1656,"style":1248},[140],[59,1658,1659,1662],{"style":1251},[59,1660],{"className":1661,"style":540},[148],[59,1663,1665],{"className":1664},[153,154,155,156],[59,1666,96],{"className":1667,"style":222},[119,192,156],[59,1669,227],{"className":1670},[226],[59,1672,1674],{"className":1673},[136],[59,1675,1677],{"className":1676,"style":1270},[140],[59,1678],{}," 值较大时，点积的值会很大，导致 softmax 集中在很小的一个区域内。为了更好地说明这一点，我们假设查询 ",[59,1681,1683,1697],{"className":1682},[62],[59,1684,1686],{"className":1685},[66],[68,1687,1688],{"xmlns":70},[72,1689,1690,1695],{},[75,1691,1692],{},[91,1693,1694],{},"q",[98,1696,1694],{"encoding":100},[59,1698,1700],{"className":1699,"ariaHidden":106},[105],[59,1701,1703,1707],{"className":1702},[110],[59,1704],{"className":1705,"style":1706},[114],"height:0.625em;vertical-align:-0.1944em;",[59,1708,1694],{"className":1709,"style":441},[119,192]," 和键 ",[59,1712,1714,1727],{"className":1713},[62],[59,1715,1717],{"className":1716},[66],[68,1718,1719],{"xmlns":70},[72,1720,1721,1725],{},[75,1722,1723],{},[91,1724,96],{},[98,1726,96],{"encoding":100},[59,1728,1730],{"className":1729,"ariaHidden":106},[105],[59,1731,1733,1736],{"className":1732},[110],[59,1734],{"className":1735,"style":407},[114],[59,1737,96],{"className":1738,"style":222},[119,192]," 都符合一个均值为 ",[59,1741,1743,1757],{"className":1742},[62],[59,1744,1746],{"className":1745},[66],[68,1747,1748],{"xmlns":70},[72,1749,1750,1755],{},[75,1751,1752],{},[81,1753,1754],{},"0",[98,1756,1754],{"encoding":100},[59,1758,1760],{"className":1759,"ariaHidden":106},[105],[59,1761,1763,1767],{"className":1762},[110],[59,1764],{"className":1765,"style":1766},[114],"height:0.6444em;",[59,1768,1754],{"className":1769},[119],"，方差为 ",[59,1772,1774,1787],{"className":1773},[62],[59,1775,1777],{"className":1776},[66],[68,1778,1779],{"xmlns":70},[72,1780,1781,1785],{},[75,1782,1783],{},[81,1784,83],{},[98,1786,83],{"encoding":100},[59,1788,1790],{"className":1789,"ariaHidden":106},[105],[59,1791,1793,1796],{"className":1792},[110],[59,1794],{"className":1795,"style":1766},[114],[59,1797,83],{"className":1798},[119]," 的随机变量。他们点积的结果 ",[59,1801,1803,1856],{"className":1802},[62],[59,1804,1806],{"className":1805},[66],[68,1807,1808],{"xmlns":70},[72,1809,1810,1853],{},[75,1811,1812,1814,1817,1819,1821,1841,1847],{},[91,1813,1694],{},[461,1815,1816],{},"⋅",[91,1818,96],{},[461,1820,815],{},[1822,1823,1824,1827,1835],"msubsup",{},[461,1825,1826],{},"∑",[75,1828,1829,1831,1833],{},[91,1830,789],{},[461,1832,815],{},[81,1834,83],{},[88,1836,1837,1839],{},[91,1838,93],{},[91,1840,96],{},[88,1842,1843,1845],{},[91,1844,1694],{},[91,1846,789],{},[88,1848,1849,1851],{},[91,1850,96],{},[91,1852,789],{},[98,1854,1855],{"encoding":100},"q \\cdot k = \\sum_{i=1}^{d_k}q_{i}k_{i}",[59,1857,1859,1879,1897],{"className":1858,"ariaHidden":106},[105],[59,1860,1862,1866,1869,1873,1876],{"className":1861},[110],[59,1863],{"className":1864,"style":1865},[114],"height:0.6389em;vertical-align:-0.1944em;",[59,1867,1694],{"className":1868,"style":441},[119,192],[59,1870],{"className":1871,"style":1872},[499],"margin-right:0.2222em;",[59,1874,1816],{"className":1875},[553],[59,1877],{"className":1878,"style":1872},[499],[59,1880,1882,1885,1888,1891,1894],{"className":1881},[110],[59,1883],{"className":1884,"style":407},[114],[59,1886,96],{"className":1887,"style":222},[119,192],[59,1889],{"className":1890,"style":500},[499],[59,1892,815],{"className":1893},[504],[59,1895],{"className":1896,"style":500},[499],[59,1898,1900,1904,2011,2014,2059],{"className":1899},[110],[59,1901],{"className":1902,"style":1903},[114],"height:1.2887em;vertical-align:-0.2997em;",[59,1905,1907,1913],{"className":1906},[951],[59,1908,1826],{"className":1909,"style":1912},[951,1910,1911],"op-symbol","small-op","position:relative;top:0em;",[59,1914,1916],{"className":1915},[196],[59,1917,1919,2002],{"className":1918},[131,132],[59,1920,1922,1999],{"className":1921},[136],[59,1923,1926,1947],{"className":1924,"style":1925},[140],"height:0.989em;",[59,1927,1929,1932],{"style":1928},"top:-2.4003em;margin-left:0em;margin-right:0.05em;",[59,1930],{"className":1931,"style":540},[148],[59,1933,1935],{"className":1934},[153,154,155,156],[59,1936,1938,1941,1944],{"className":1937},[119,156],[59,1939,789],{"className":1940},[119,192,156],[59,1942,815],{"className":1943},[504,156],[59,1945,83],{"className":1946},[119,156],[59,1948,1950,1953],{"style":1949},"top:-3.2029em;margin-right:0.05em;",[59,1951],{"className":1952,"style":540},[148],[59,1954,1956],{"className":1955},[153,154,155,156],[59,1957,1959],{"className":1958},[119,156],[59,1960,1962,1965],{"className":1961},[119,156],[59,1963,93],{"className":1964},[119,192,156],[59,1966,1968],{"className":1967},[196],[59,1969,1971,1991],{"className":1970},[131,132],[59,1972,1974,1988],{"className":1973},[136],[59,1975,1977],{"className":1976,"style":206},[140],[59,1978,1979,1982],{"style":209},[59,1980],{"className":1981,"style":213},[148],[59,1983,1985],{"className":1984},[153,217,218,156],[59,1986,96],{"className":1987,"style":222},[119,192,156],[59,1989,227],{"className":1990},[226],[59,1992,1994],{"className":1993},[136],[59,1995,1997],{"className":1996,"style":234},[140],[59,1998],{},[59,2000,227],{"className":2001},[226],[59,2003,2005],{"className":2004},[136],[59,2006,2009],{"className":2007,"style":2008},[140],"height:0.2997em;",[59,2010],{},[59,2012],{"className":2013,"style":916},[499],[59,2015,2017,2020],{"className":2016},[119],[59,2018,1694],{"className":2019,"style":441},[119,192],[59,2021,2023],{"className":2022},[196],[59,2024,2026,2051],{"className":2025},[131,132],[59,2027,2029,2048],{"className":2028},[136],[59,2030,2033],{"className":2031,"style":2032},[140],"height:0.3117em;",[59,2034,2036,2039],{"style":2035},"top:-2.55em;margin-left:-0.0359em;margin-right:0.05em;",[59,2037],{"className":2038,"style":540},[148],[59,2040,2042],{"className":2041},[153,154,155,156],[59,2043,2045],{"className":2044},[119,156],[59,2046,789],{"className":2047},[119,192,156],[59,2049,227],{"className":2050},[226],[59,2052,2054],{"className":2053},[136],[59,2055,2057],{"className":2056,"style":1270},[140],[59,2058],{},[59,2060,2062,2065],{"className":2061},[119],[59,2063,96],{"className":2064,"style":222},[119,192],[59,2066,2068],{"className":2067},[196],[59,2069,2071,2095],{"className":2070},[131,132],[59,2072,2074,2092],{"className":2073},[136],[59,2075,2077],{"className":2076,"style":2032},[140],[59,2078,2080,2083],{"style":2079},"top:-2.55em;margin-left:-0.0315em;margin-right:0.05em;",[59,2081],{"className":2082,"style":540},[148],[59,2084,2086],{"className":2085},[153,154,155,156],[59,2087,2089],{"className":2088},[119,156],[59,2090,789],{"className":2091},[119,192,156],[59,2093,227],{"className":2094},[226],[59,2096,2098],{"className":2097},[136],[59,2099,2101],{"className":2100,"style":1270},[140],[59,2102],{}," 的均值为 ",[59,2105,2107,2120],{"className":2106},[62],[59,2108,2110],{"className":2109},[66],[68,2111,2112],{"xmlns":70},[72,2113,2114,2118],{},[75,2115,2116],{},[81,2117,1754],{},[98,2119,1754],{"encoding":100},[59,2121,2123],{"className":2122,"ariaHidden":106},[105],[59,2124,2126,2129],{"className":2125},[110],[59,2127],{"className":2128,"style":1766},[114],[59,2130,1754],{"className":2131},[119],[59,2133,2135,2152],{"className":2134},[62],[59,2136,2138],{"className":2137},[66],[68,2139,2140],{"xmlns":70},[72,2141,2142,2150],{},[75,2143,2144],{},[88,2145,2146,2148],{},[91,2147,93],{},[91,2149,96],{},[98,2151,1219],{"encoding":100},[59,2153,2155],{"className":2154,"ariaHidden":106},[105],[59,2156,2158,2161],{"className":2157},[110],[59,2159],{"className":2160,"style":1229},[114],[59,2162,2164,2167],{"className":2163},[119],[59,2165,93],{"className":2166},[119,192],[59,2168,2170],{"className":2169},[196],[59,2171,2173,2193],{"className":2172},[131,132],[59,2174,2176,2190],{"className":2175},[136],[59,2177,2179],{"className":2178,"style":1248},[140],[59,2180,2181,2184],{"style":1251},[59,2182],{"className":2183,"style":540},[148],[59,2185,2187],{"className":2186},[153,154,155,156],[59,2188,96],{"className":2189,"style":222},[119,192,156],[59,2191,227],{"className":2192},[226],[59,2194,2196],{"className":2195},[136],[59,2197,2199],{"className":2198,"style":1270},[140],[59,2200],{},"，为了抵消这种影响，我们要乘以一个缩放因子 ",[59,2203,2205,2228],{"className":2204},[62],[59,2206,2208],{"className":2207},[66],[68,2209,2210],{"xmlns":70},[72,2211,2212,2226],{},[75,2213,2214],{},[78,2215,2216,2218],{},[81,2217,83],{},[85,2219,2220],{},[88,2221,2222,2224],{},[91,2223,93],{},[91,2225,96],{},[98,2227,101],{"encoding":100},[59,2229,2231],{"className":2230,"ariaHidden":106},[105],[59,2232,2234,2237],{"className":2233},[110],[59,2235],{"className":2236,"style":115},[114],[59,2238,2240,2243,2383],{"className":2239},[119],[59,2241],{"className":2242},[123,124],[59,2244,2246],{"className":2245},[78],[59,2247,2249,2375],{"className":2248},[131,132],[59,2250,2252,2372],{"className":2251},[136],[59,2253,2255,2350,2358],{"className":2254,"style":141},[140],[59,2256,2257,2260],{"style":144},[59,2258],{"className":2259,"style":149},[148],[59,2261,2263],{"className":2262},[153,154,155,156],[59,2264,2266],{"className":2265},[119,156],[59,2267,2269],{"className":2268},[119,163,156],[59,2270,2272,2342],{"className":2271},[131,132],[59,2273,2275,2339],{"className":2274},[136],[59,2276,2278,2327],{"className":2277,"style":173},[140],[59,2279,2281,2284],{"className":2280,"style":178},[177],[59,2282],{"className":2283,"style":149},[148],[59,2285,2287],{"className":2286,"style":185},[119,156],[59,2288,2290,2293],{"className":2289},[119,156],[59,2291,93],{"className":2292},[119,192,156],[59,2294,2296],{"className":2295},[196],[59,2297,2299,2319],{"className":2298},[131,132],[59,2300,2302,2316],{"className":2301},[136],[59,2303,2305],{"className":2304,"style":206},[140],[59,2306,2307,2310],{"style":209},[59,2308],{"className":2309,"style":213},[148],[59,2311,2313],{"className":2312},[153,217,218,156],[59,2314,96],{"className":2315,"style":222},[119,192,156],[59,2317,227],{"className":2318},[226],[59,2320,2322],{"className":2321},[136],[59,2323,2325],{"className":2324,"style":234},[140],[59,2326],{},[59,2328,2329,2332],{"style":239},[59,2330],{"className":2331,"style":149},[148],[59,2333,2335],{"className":2334,"style":247},[246,156],[249,2336,2337],{"xmlns":251,"width":252,"height":253,"viewBox":254,"preserveAspectRatio":255},[257,2338],{"d":259},[59,2340,227],{"className":2341},[226],[59,2343,2345],{"className":2344},[136],[59,2346,2348],{"className":2347,"style":269},[140],[59,2349],{},[59,2351,2352,2355],{"style":274},[59,2353],{"className":2354,"style":149},[148],[59,2356],{"className":2357,"style":282},[281],[59,2359,2360,2363],{"style":285},[59,2361],{"className":2362,"style":149},[148],[59,2364,2366],{"className":2365},[153,154,155,156],[59,2367,2369],{"className":2368},[119,156],[59,2370,83],{"className":2371},[119,156],[59,2373,227],{"className":2374},[226],[59,2376,2378],{"className":2377},[136],[59,2379,2381],{"className":2380,"style":307},[140],[59,2382],{},[59,2384],{"className":2385},[313,124],[15,2387,2388],{},"因为在训练和推理时，并非所有的值都应该加入注意力汇聚操作中，所以我们首先实现 masked_softmax 函数来保证在注意力汇聚时仅加入有意义的值。",[2390,2391,2396],"pre",{"className":2392,"code":2393,"language":2394,"meta":2395,"style":2395},"language-python shiki shiki-themes github-light dracula","def masked_softmax(X, valid_lens):\n    \"\"\"通过在最后一个轴上掩蔽元素来执行softmax操作\"\"\"\n    # X:3D张量，valid_lens:1D或2D张量\n    if valid_lens is None:\n        return nn.functional.softmax(X, dim=-1)\n    else:\n        shape = X.shape\n        if valid_lens.dim() == 1:\n            valid_lens = torch.repeat_interleave(valid_lens, shape[1])\n        else:\n            valid_lens = valid_lens.reshape(-1)\n        # 最后一轴上被掩蔽的元素使用一个非常大的负值替换，从而其softmax输出为0\n        X = d2l.sequence_mask(X.reshape(-1, shape[-1]), valid_lens,\n                              value=-1e6)\n        return nn.functional.softmax(X.reshape(shape), dim=-1)\n","python","",[2397,2398,2399,2427,2434,2441,2460,2481,2489,2500,2517,2533,2541,2558,2564,2589,2602],"code",{"__ignoreMap":2395},[59,2400,2403,2407,2411,2414,2418,2421,2424],{"class":2401,"line":2402},"line",1,[59,2404,2406],{"class":2405},"szJfE","def",[59,2408,2410],{"class":2409},"sCLZk"," masked_softmax",[59,2412,798],{"class":2413},"scbbO",[59,2415,2417],{"class":2416},"syNf4","X",[59,2419,2420],{"class":2413},", ",[59,2422,2423],{"class":2416},"valid_lens",[59,2425,2426],{"class":2413},"):\n",[59,2428,2430],{"class":2401,"line":2429},2,[59,2431,2433],{"class":2432},"seLWX","    \"\"\"通过在最后一个轴上掩蔽元素来执行softmax操作\"\"\"\n",[59,2435,2437],{"class":2401,"line":2436},3,[59,2438,2440],{"class":2439},"sfgPZ","    # X:3D张量，valid_lens:1D或2D张量\n",[59,2442,2444,2447,2450,2453,2457],{"class":2401,"line":2443},4,[59,2445,2446],{"class":2405},"    if",[59,2448,2449],{"class":2413}," valid_lens ",[59,2451,2452],{"class":2405},"is",[59,2454,2456],{"class":2455},"soDru"," None",[59,2458,2459],{"class":2413},":\n",[59,2461,2463,2466,2469,2473,2476,2478],{"class":2401,"line":2462},5,[59,2464,2465],{"class":2405},"        return",[59,2467,2468],{"class":2413}," nn.functional.softmax(X, ",[59,2470,2472],{"class":2471},"sQkXh","dim",[59,2474,2475],{"class":2405},"=-",[59,2477,83],{"class":2455},[59,2479,2480],{"class":2413},")\n",[59,2482,2484,2487],{"class":2401,"line":2483},6,[59,2485,2486],{"class":2405},"    else",[59,2488,2459],{"class":2413},[59,2490,2492,2495,2497],{"class":2401,"line":2491},7,[59,2493,2494],{"class":2413},"        shape ",[59,2496,815],{"class":2405},[59,2498,2499],{"class":2413}," X.shape\n",[59,2501,2503,2506,2509,2512,2515],{"class":2401,"line":2502},8,[59,2504,2505],{"class":2405},"        if",[59,2507,2508],{"class":2413}," valid_lens.dim() ",[59,2510,2511],{"class":2405},"==",[59,2513,2514],{"class":2455}," 1",[59,2516,2459],{"class":2413},[59,2518,2520,2523,2525,2528,2530],{"class":2401,"line":2519},9,[59,2521,2522],{"class":2413},"            valid_lens ",[59,2524,815],{"class":2405},[59,2526,2527],{"class":2413}," torch.repeat_interleave(valid_lens, shape[",[59,2529,83],{"class":2455},[59,2531,2532],{"class":2413},"])\n",[59,2534,2536,2539],{"class":2401,"line":2535},10,[59,2537,2538],{"class":2405},"        else",[59,2540,2459],{"class":2413},[59,2542,2544,2546,2548,2551,2554,2556],{"class":2401,"line":2543},11,[59,2545,2522],{"class":2413},[59,2547,815],{"class":2405},[59,2549,2550],{"class":2413}," valid_lens.reshape(",[59,2552,2553],{"class":2405},"-",[59,2555,83],{"class":2455},[59,2557,2480],{"class":2413},[59,2559,2561],{"class":2401,"line":2560},12,[59,2562,2563],{"class":2439},"        # 最后一轴上被掩蔽的元素使用一个非常大的负值替换，从而其softmax输出为0\n",[59,2565,2567,2570,2572,2575,2577,2579,2582,2584,2586],{"class":2401,"line":2566},13,[59,2568,2569],{"class":2413},"        X ",[59,2571,815],{"class":2405},[59,2573,2574],{"class":2413}," d2l.sequence_mask(X.reshape(",[59,2576,2553],{"class":2405},[59,2578,83],{"class":2455},[59,2580,2581],{"class":2413},", shape[",[59,2583,2553],{"class":2405},[59,2585,83],{"class":2455},[59,2587,2588],{"class":2413},"]), valid_lens,\n",[59,2590,2592,2595,2597,2600],{"class":2401,"line":2591},14,[59,2593,2594],{"class":2471},"                              value",[59,2596,2475],{"class":2405},[59,2598,2599],{"class":2455},"1e6",[59,2601,2480],{"class":2413},[59,2603,2605,2607,2610,2612,2614,2616],{"class":2401,"line":2604},15,[59,2606,2465],{"class":2405},[59,2608,2609],{"class":2413}," nn.functional.softmax(X.reshape(shape), ",[59,2611,2472],{"class":2471},[59,2613,2475],{"class":2405},[59,2615,83],{"class":2455},[59,2617,2480],{"class":2413},[15,2619,2620],{},"下面实现缩放点积注意力。",[2390,2622,2624],{"className":2392,"code":2623,"language":2394,"meta":2395,"style":2395},"class DotProductAttention(nn.Module):\n    \"\"\"缩放点积注意力\"\"\"\n    def __init__(self, dropout, **kwargs):\n        super(DotProductAttention, self).__init__(**kwargs)\n        self.dropout = nn.Dropout(dropout)\n\n    # queries的形状：(batch_size，查询的个数，d)\n    # keys的形状：(batch_size，“键－值”对的个数，d)\n    # values的形状：(batch_size，“键－值”对的个数，值的维度)\n    # valid_lens的形状:(batch_size，)或者(batch_size，查询的个数)\n    def forward(self, queries, keys, values, valid_lens=None):\n        d = queries.shape[-1]\n        # 设置transpose_b=True为了交换keys的最后两个维度\n        scores = torch.bmm(queries, keys.transpose(1,2)) / math.sqrt(d)\n        self.attention_weights = masked_softmax(scores, valid_lens)\n        return torch.bmm(self.dropout(self.attention_weights), values)\n",[2397,2625,2626,2648,2653,2682,2707,2720,2726,2731,2736,2741,2746,2783,2800,2805,2831,2843],{"__ignoreMap":2395},[59,2627,2628,2631,2635,2637,2641,2643,2646],{"class":2401,"line":2402},[59,2629,2630],{"class":2405},"class",[59,2632,2634],{"class":2633},"skCyd"," DotProductAttention",[59,2636,798],{"class":2413},[59,2638,2640],{"class":2639},"sDP9b","nn",[59,2642,864],{"class":2413},[59,2644,2645],{"class":2639},"Module",[59,2647,2426],{"class":2413},[59,2649,2650],{"class":2401,"line":2429},[59,2651,2652],{"class":2432},"    \"\"\"缩放点积注意力\"\"\"\n",[59,2654,2655,2658,2661,2663,2667,2669,2672,2674,2677,2680],{"class":2401,"line":2436},[59,2656,2657],{"class":2405},"    def",[59,2659,2660],{"class":2455}," __init__",[59,2662,798],{"class":2413},[59,2664,2666],{"class":2665},"sD3jg","self",[59,2668,2420],{"class":2413},[59,2670,2671],{"class":2416},"dropout",[59,2673,2420],{"class":2413},[59,2675,2676],{"class":2405},"**",[59,2678,2679],{"class":2416},"kwargs",[59,2681,2426],{"class":2413},[59,2683,2684,2688,2691,2694,2697,2700,2702,2704],{"class":2401,"line":2443},[59,2685,2687],{"class":2686},"sPGBF","        super",[59,2689,2690],{"class":2413},"(DotProductAttention, ",[59,2692,2666],{"class":2693},"sJti5",[59,2695,2696],{"class":2413},").",[59,2698,2699],{"class":2455},"__init__",[59,2701,798],{"class":2413},[59,2703,2676],{"class":2405},[59,2705,2706],{"class":2413},"kwargs)\n",[59,2708,2709,2712,2715,2717],{"class":2401,"line":2462},[59,2710,2711],{"class":2693},"        self",[59,2713,2714],{"class":2413},".dropout ",[59,2716,815],{"class":2405},[59,2718,2719],{"class":2413}," nn.Dropout(dropout)\n",[59,2721,2722],{"class":2401,"line":2483},[59,2723,2725],{"emptyLinePlaceholder":2724},true,"\n",[59,2727,2728],{"class":2401,"line":2491},[59,2729,2730],{"class":2439},"    # queries的形状：(batch_size，查询的个数，d)\n",[59,2732,2733],{"class":2401,"line":2502},[59,2734,2735],{"class":2439},"    # keys的形状：(batch_size，“键－值”对的个数，d)\n",[59,2737,2738],{"class":2401,"line":2519},[59,2739,2740],{"class":2439},"    # values的形状：(batch_size，“键－值”对的个数，值的维度)\n",[59,2742,2743],{"class":2401,"line":2535},[59,2744,2745],{"class":2439},"    # valid_lens的形状:(batch_size，)或者(batch_size，查询的个数)\n",[59,2747,2748,2750,2753,2755,2757,2759,2762,2764,2767,2769,2772,2774,2776,2778,2781],{"class":2401,"line":2543},[59,2749,2657],{"class":2405},[59,2751,2752],{"class":2409}," forward",[59,2754,798],{"class":2413},[59,2756,2666],{"class":2665},[59,2758,2420],{"class":2413},[59,2760,2761],{"class":2416},"queries",[59,2763,2420],{"class":2413},[59,2765,2766],{"class":2416},"keys",[59,2768,2420],{"class":2413},[59,2770,2771],{"class":2416},"values",[59,2773,2420],{"class":2413},[59,2775,2423],{"class":2416},[59,2777,815],{"class":2405},[59,2779,2780],{"class":2455},"None",[59,2782,2426],{"class":2413},[59,2784,2785,2788,2790,2793,2795,2797],{"class":2401,"line":2560},[59,2786,2787],{"class":2413},"        d ",[59,2789,815],{"class":2405},[59,2791,2792],{"class":2413}," queries.shape[",[59,2794,2553],{"class":2405},[59,2796,83],{"class":2455},[59,2798,2799],{"class":2413},"]\n",[59,2801,2802],{"class":2401,"line":2566},[59,2803,2804],{"class":2439},"        # 设置transpose_b=True为了交换keys的最后两个维度\n",[59,2806,2807,2810,2812,2815,2817,2819,2822,2825,2828],{"class":2401,"line":2591},[59,2808,2809],{"class":2413},"        scores ",[59,2811,815],{"class":2405},[59,2813,2814],{"class":2413}," torch.bmm(queries, keys.transpose(",[59,2816,83],{"class":2455},[59,2818,803],{"class":2413},[59,2820,2821],{"class":2455},"2",[59,2823,2824],{"class":2413},")) ",[59,2826,2827],{"class":2405},"/",[59,2829,2830],{"class":2413}," math.sqrt(d)\n",[59,2832,2833,2835,2838,2840],{"class":2401,"line":2604},[59,2834,2711],{"class":2693},[59,2836,2837],{"class":2413},".attention_weights ",[59,2839,815],{"class":2405},[59,2841,2842],{"class":2413}," masked_softmax(scores, valid_lens)\n",[59,2844,2846,2848,2851,2853,2856,2858],{"class":2401,"line":2845},16,[59,2847,2465],{"class":2405},[59,2849,2850],{"class":2413}," torch.bmm(",[59,2852,2666],{"class":2693},[59,2854,2855],{"class":2413},".dropout(",[59,2857,2666],{"class":2693},[59,2859,2860],{"class":2413},".attention_weights), values)\n",[15,2862,2863,2864,2867],{},"代码中函数的 ",[2397,2865,2866],{},"**kwargs","表示任意关键字参数，为了符合 Pytorch 的继承约定。",[49,2869,2870],{"id":2870},"多头注意力",[15,2872,2873],{},"多头注意力机制允许模型在不同位置上同时关注来自不同表示子空间的信息。使用单一注意力头时，平均操作会抑制这种能力。",[15,2875,2876],{},[41,2877],{"alt":2395,"src":2878},"transformer/MHA.png",[15,2880,2881],{},[59,2882,2884,3000],{"className":2883},[62],[59,2885,2887],{"className":2886},[66],[68,2888,2889],{"xmlns":70},[72,2890,2891,2997],{},[75,2892,2893,2917,2919,2921,2923,2925,2927,2929,2931,2933,2949,2951,2966,2968,2971,2973,2987,2989],{},[75,2894,2895,2898,2901,2904,2906,2908,2911,2913,2915],{},[91,2896,2897],{"mathvariant":818},"M",[91,2899,2900],{"mathvariant":818},"u",[91,2902,2903],{"mathvariant":818},"l",[91,2905,777],{"mathvariant":818},[91,2907,789],{"mathvariant":818},[91,2909,2910],{"mathvariant":818},"H",[91,2912,782],{"mathvariant":818},[91,2914,22],{"mathvariant":818},[91,2916,93],{"mathvariant":818},[461,2918,798],{"stretchy":797},[91,2920,459],{},[461,2922,803],{"separator":106},[91,2924,572],{},[461,2926,803],{"separator":106},[91,2928,672],{},[461,2930,812],{"stretchy":797},[461,2932,815],{},[75,2934,2935,2938,2940,2942,2945,2947],{},[91,2936,2937],{"mathvariant":818},"C",[91,2939,792],{"mathvariant":818},[91,2941,335],{"mathvariant":818},[91,2943,2944],{"mathvariant":818},"c",[91,2946,22],{"mathvariant":818},[91,2948,777],{"mathvariant":818},[461,2950,798],{"stretchy":797},[88,2952,2953,2964],{},[75,2954,2955,2958,2960,2962],{},[91,2956,2957],{"mathvariant":818},"h",[91,2959,782],{"mathvariant":818},[91,2961,22],{"mathvariant":818},[91,2963,93],{"mathvariant":818},[81,2965,83],{},[461,2967,803],{"separator":106},[461,2969,2970],{},"…",[461,2972,803],{"separator":106},[88,2974,2975,2985],{},[75,2976,2977,2979,2981,2983],{},[91,2978,2957],{"mathvariant":818},[91,2980,782],{"mathvariant":818},[91,2982,22],{"mathvariant":818},[91,2984,93],{"mathvariant":818},[91,2986,2957],{},[461,2988,812],{"stretchy":797},[465,2990,2991,2994],{},[91,2992,2993],{},"W",[91,2995,2996],{},"O",[98,2998,2999],{"encoding":100},"\\mathrm{MultiHead}(Q, K, V) = \\mathrm{Concat}(\\mathrm{head}_1, \\ldots, \\mathrm{head}_h) W^O",[59,3001,3003,3052],{"className":3002,"ariaHidden":106},[105],[59,3004,3006,3009,3016,3019,3022,3025,3028,3031,3034,3037,3040,3043,3046,3049],{"className":3005},[110],[59,3007],{"className":3008,"style":877},[114],[59,3010,3012],{"className":3011},[119],[59,3013,3015],{"className":3014},[119,955],"MultiHead",[59,3017,798],{"className":3018},[123],[59,3020,459],{"className":3021},[119,192],[59,3023,803],{"className":3024},[912],[59,3026],{"className":3027,"style":916},[499],[59,3029,572],{"className":3030,"style":603},[119,192],[59,3032,803],{"className":3033},[912],[59,3035],{"className":3036,"style":916},[499],[59,3038,672],{"className":3039,"style":702},[119,192],[59,3041,812],{"className":3042},[313],[59,3044],{"className":3045,"style":500},[499],[59,3047,815],{"className":3048},[504],[59,3050],{"className":3051,"style":500},[499],[59,3053,3055,3059,3066,3069,3115,3118,3121,3124,3127,3130,3133,3176,3179],{"className":3054},[110],[59,3056],{"className":3057,"style":3058},[114],"height:1.0913em;vertical-align:-0.25em;",[59,3060,3062],{"className":3061},[119],[59,3063,3065],{"className":3064},[119,955],"Concat",[59,3067,798],{"className":3068},[123],[59,3070,3072,3079],{"className":3071},[119],[59,3073,3075],{"className":3074},[119],[59,3076,3078],{"className":3077},[119,955],"head",[59,3080,3082],{"className":3081},[196],[59,3083,3085,3107],{"className":3084},[131,132],[59,3086,3088,3104],{"className":3087},[136],[59,3089,3092],{"className":3090,"style":3091},[140],"height:0.3011em;",[59,3093,3095,3098],{"style":3094},"top:-2.55em;margin-right:0.05em;",[59,3096],{"className":3097,"style":540},[148],[59,3099,3101],{"className":3100},[153,154,155,156],[59,3102,83],{"className":3103},[119,156],[59,3105,227],{"className":3106},[226],[59,3108,3110],{"className":3109},[136],[59,3111,3113],{"className":3112,"style":1270},[140],[59,3114],{},[59,3116,803],{"className":3117},[912],[59,3119],{"className":3120,"style":916},[499],[59,3122,2970],{"className":3123},[962],[59,3125],{"className":3126,"style":916},[499],[59,3128,803],{"className":3129},[912],[59,3131],{"className":3132,"style":916},[499],[59,3134,3136,3142],{"className":3135},[119],[59,3137,3139],{"className":3138},[119],[59,3140,3078],{"className":3141},[119,955],[59,3143,3145],{"className":3144},[196],[59,3146,3148,3168],{"className":3147},[131,132],[59,3149,3151,3165],{"className":3150},[136],[59,3152,3154],{"className":3153,"style":1248},[140],[59,3155,3156,3159],{"style":3094},[59,3157],{"className":3158,"style":540},[148],[59,3160,3162],{"className":3161},[153,154,155,156],[59,3163,2957],{"className":3164},[119,192,156],[59,3166,227],{"className":3167},[226],[59,3169,3171],{"className":3170},[136],[59,3172,3174],{"className":3173,"style":1270},[140],[59,3175],{},[59,3177,812],{"className":3178},[313],[59,3180,3182,3186],{"className":3181},[119],[59,3183,2993],{"className":3184,"style":3185},[119,192],"margin-right:0.13889em;",[59,3187,3189],{"className":3188},[196],[59,3190,3192],{"className":3191},[131],[59,3193,3195],{"className":3194},[136],[59,3196,3199],{"className":3197,"style":3198},[140],"height:0.8413em;",[59,3200,3201,3204],{"style":536},[59,3202],{"className":3203,"style":540},[148],[59,3205,3207],{"className":3206},[153,154,155,156],[59,3208,2996],{"className":3209,"style":3210},[119,192,156],"margin-right:0.02778em;",[15,3212,3213,3214,3576,3577,3826,3827,3826,4076,4327,4328,864],{},"其中 ",[59,3215,3217,3309],{"className":3216},[62],[59,3218,3220],{"className":3219},[66],[68,3221,3222],{"xmlns":70},[72,3223,3224,3306],{},[75,3225,3226,3240,3242,3262,3264,3266,3274,3276,3280,3282,3290,3292,3294,3296,3304],{},[88,3227,3228,3238],{},[75,3229,3230,3232,3234,3236],{},[91,3231,2957],{"mathvariant":818},[91,3233,782],{"mathvariant":818},[91,3235,22],{"mathvariant":818},[91,3237,93],{"mathvariant":818},[91,3239,789],{},[461,3241,815],{},[75,3243,3244,3246,3248,3250,3252,3254,3256,3258,3260],{},[91,3245,774],{"mathvariant":818},[91,3247,777],{"mathvariant":818},[91,3249,777],{"mathvariant":818},[91,3251,782],{"mathvariant":818},[91,3253,335],{"mathvariant":818},[91,3255,777],{"mathvariant":818},[91,3257,789],{"mathvariant":818},[91,3259,792],{"mathvariant":818},[91,3261,335],{"mathvariant":818},[461,3263,798],{"stretchy":797},[91,3265,459],{},[1822,3267,3268,3270,3272],{},[91,3269,2993],{},[91,3271,789],{},[91,3273,459],{},[461,3275,803],{"separator":106},[3277,3278,3279],"mtext",{},"  ",[91,3281,572],{},[1822,3283,3284,3286,3288],{},[91,3285,2993],{},[91,3287,789],{},[91,3289,572],{},[461,3291,803],{"separator":106},[3277,3293,3279],{},[91,3295,672],{},[1822,3297,3298,3300,3302],{},[91,3299,2993],{},[91,3301,789],{},[91,3303,672],{},[461,3305,812],{"stretchy":797},[98,3307,3308],{"encoding":100},"\\mathrm{head}_i=\\mathrm{Attention}(QW_i^Q,\\; KW_i^K,\\; VW_i^V)",[59,3310,3312,3370],{"className":3311,"ariaHidden":106},[105],[59,3313,3315,3318,3361,3364,3367],{"className":3314},[110],[59,3316],{"className":3317,"style":1229},[114],[59,3319,3321,3327],{"className":3320},[119],[59,3322,3324],{"className":3323},[119],[59,3325,3078],{"className":3326},[119,955],[59,3328,3330],{"className":3329},[196],[59,3331,3333,3353],{"className":3332},[131,132],[59,3334,3336,3350],{"className":3335},[136],[59,3337,3339],{"className":3338,"style":2032},[140],[59,3340,3341,3344],{"style":3094},[59,3342],{"className":3343,"style":540},[148],[59,3345,3347],{"className":3346},[153,154,155,156],[59,3348,789],{"className":3349},[119,192,156],[59,3351,227],{"className":3352},[226],[59,3354,3356],{"className":3355},[136],[59,3357,3359],{"className":3358,"style":1270},[140],[59,3360],{},[59,3362],{"className":3363,"style":500},[499],[59,3365,815],{"className":3366},[504],[59,3368],{"className":3369,"style":500},[499],[59,3371,3373,3377,3384,3387,3390,3445,3448,3451,3454,3457,3510,3513,3516,3519,3522,3573],{"className":3372},[110],[59,3374],{"className":3375,"style":3376},[114],"height:1.2361em;vertical-align:-0.2769em;",[59,3378,3380],{"className":3379},[119],[59,3381,3383],{"className":3382},[119,955],"Attention",[59,3385,798],{"className":3386},[123],[59,3388,459],{"className":3389},[119,192],[59,3391,3393,3396],{"className":3392},[119],[59,3394,2993],{"className":3395,"style":3185},[119,192],[59,3397,3399],{"className":3398},[196],[59,3400,3402,3436],{"className":3401},[131,132],[59,3403,3405,3433],{"className":3404},[136],[59,3406,3409,3421],{"className":3407,"style":3408},[140],"height:0.9592em;",[59,3410,3412,3415],{"style":3411},"top:-2.4231em;margin-left:-0.1389em;margin-right:0.05em;",[59,3413],{"className":3414,"style":540},[148],[59,3416,3418],{"className":3417},[153,154,155,156],[59,3419,789],{"className":3420},[119,192,156],[59,3422,3424,3427],{"style":3423},"top:-3.1809em;margin-right:0.05em;",[59,3425],{"className":3426,"style":540},[148],[59,3428,3430],{"className":3429},[153,154,155,156],[59,3431,459],{"className":3432},[119,192,156],[59,3434,227],{"className":3435},[226],[59,3437,3439],{"className":3438},[136],[59,3440,3443],{"className":3441,"style":3442},[140],"height:0.2769em;",[59,3444],{},[59,3446,803],{"className":3447},[912],[59,3449],{"className":3450,"style":500},[499],[59,3452],{"className":3453,"style":916},[499],[59,3455,572],{"className":3456,"style":603},[119,192],[59,3458,3460,3463],{"className":3459},[119],[59,3461,2993],{"className":3462,"style":3185},[119,192],[59,3464,3466],{"className":3465},[196],[59,3467,3469,3501],{"className":3468},[131,132],[59,3470,3472,3498],{"className":3471},[136],[59,3473,3475,3487],{"className":3474,"style":3198},[140],[59,3476,3478,3481],{"style":3477},"top:-2.4413em;margin-left:-0.1389em;margin-right:0.05em;",[59,3479],{"className":3480,"style":540},[148],[59,3482,3484],{"className":3483},[153,154,155,156],[59,3485,789],{"className":3486},[119,192,156],[59,3488,3489,3492],{"style":536},[59,3490],{"className":3491,"style":540},[148],[59,3493,3495],{"className":3494},[153,154,155,156],[59,3496,572],{"className":3497,"style":603},[119,192,156],[59,3499,227],{"className":3500},[226],[59,3502,3504],{"className":3503},[136],[59,3505,3508],{"className":3506,"style":3507},[140],"height:0.2587em;",[59,3509],{},[59,3511,803],{"className":3512},[912],[59,3514],{"className":3515,"style":500},[499],[59,3517],{"className":3518,"style":916},[499],[59,3520,672],{"className":3521,"style":702},[119,192],[59,3523,3525,3528],{"className":3524},[119],[59,3526,2993],{"className":3527,"style":3185},[119,192],[59,3529,3531],{"className":3530},[196],[59,3532,3534,3565],{"className":3533},[131,132],[59,3535,3537,3562],{"className":3536},[136],[59,3538,3540,3551],{"className":3539,"style":3198},[140],[59,3541,3542,3545],{"style":3477},[59,3543],{"className":3544,"style":540},[148],[59,3546,3548],{"className":3547},[153,154,155,156],[59,3549,789],{"className":3550},[119,192,156],[59,3552,3553,3556],{"style":536},[59,3554],{"className":3555,"style":540},[148],[59,3557,3559],{"className":3558},[153,154,155,156],[59,3560,672],{"className":3561,"style":702},[119,192,156],[59,3563,227],{"className":3564},[226],[59,3566,3568],{"className":3567},[136],[59,3569,3571],{"className":3570,"style":3507},[140],[59,3572],{},[59,3574,812],{"className":3575},[313],"，投影参数是矩阵\n",[59,3578,3580,3632],{"className":3579},[62],[59,3581,3583],{"className":3582},[66],[68,3584,3585],{"xmlns":70},[72,3586,3587,3629],{},[75,3588,3589,3597,3599],{},[1822,3590,3591,3593,3595],{},[91,3592,2993],{},[91,3594,789],{},[91,3596,459],{},[461,3598,463],{},[465,3600,3601,3603],{},[91,3602,470],{"mathvariant":469},[75,3604,3605,3621,3623],{},[88,3606,3607,3609],{},[91,3608,93],{},[75,3610,3611,3613,3615,3617,3619],{},[91,3612,366],{"mathvariant":818},[91,3614,792],{"mathvariant":818},[91,3616,93],{"mathvariant":818},[91,3618,782],{"mathvariant":818},[91,3620,2903],{"mathvariant":818},[461,3622,477],{},[88,3624,3625,3627],{},[91,3626,93],{},[91,3628,96],{},[98,3630,3631],{"encoding":100},"W_i^Q \\in \\mathbb{R}^{d_{\\mathrm{model}} \\times d_k}",[59,3633,3635,3701],{"className":3634,"ariaHidden":106},[105],[59,3636,3638,3641,3692,3695,3698],{"className":3637},[110],[59,3639],{"className":3640,"style":3376},[114],[59,3642,3644,3647],{"className":3643},[119],[59,3645,2993],{"className":3646,"style":3185},[119,192],[59,3648,3650],{"className":3649},[196],[59,3651,3653,3684],{"className":3652},[131,132],[59,3654,3656,3681],{"className":3655},[136],[59,3657,3659,3670],{"className":3658,"style":3408},[140],[59,3660,3661,3664],{"style":3411},[59,3662],{"className":3663,"style":540},[148],[59,3665,3667],{"className":3666},[153,154,155,156],[59,3668,789],{"className":3669},[119,192,156],[59,3671,3672,3675],{"style":3423},[59,3673],{"className":3674,"style":540},[148],[59,3676,3678],{"className":3677},[153,154,155,156],[59,3679,459],{"className":3680},[119,192,156],[59,3682,227],{"className":3683},[226],[59,3685,3687],{"className":3686},[136],[59,3688,3690],{"className":3689,"style":3442},[140],[59,3691],{},[59,3693],{"className":3694,"style":500},[499],[59,3696,463],{"className":3697},[504],[59,3699],{"className":3700,"style":500},[499],[59,3702,3704,3707],{"className":3703},[110],[59,3705],{"className":3706,"style":514},[114],[59,3708,3710,3713],{"className":3709},[119],[59,3711,470],{"className":3712},[119,521],[59,3714,3716],{"className":3715},[196],[59,3717,3719],{"className":3718},[131],[59,3720,3722],{"className":3721},[136],[59,3723,3725],{"className":3724,"style":514},[140],[59,3726,3727,3730],{"style":536},[59,3728],{"className":3729,"style":540},[148],[59,3731,3733],{"className":3732},[153,154,155,156],[59,3734,3736,3783,3786],{"className":3735},[119,156],[59,3737,3739,3742],{"className":3738},[119,156],[59,3740,93],{"className":3741},[119,192,156],[59,3743,3745],{"className":3744},[196],[59,3746,3748,3775],{"className":3747},[131,132],[59,3749,3751,3772],{"className":3750},[136],[59,3752,3754],{"className":3753,"style":206},[140],[59,3755,3756,3759],{"style":209},[59,3757],{"className":3758,"style":213},[148],[59,3760,3762],{"className":3761},[153,217,218,156],[59,3763,3765],{"className":3764},[119,156],[59,3766,3768],{"className":3767},[119,156],[59,3769,3771],{"className":3770},[119,955,156],"model",[59,3773,227],{"className":3774},[226],[59,3776,3778],{"className":3777},[136],[59,3779,3781],{"className":3780,"style":234},[140],[59,3782],{},[59,3784,477],{"className":3785},[553,156],[59,3787,3789,3792],{"className":3788},[119,156],[59,3790,93],{"className":3791},[119,192,156],[59,3793,3795],{"className":3794},[196],[59,3796,3798,3818],{"className":3797},[131,132],[59,3799,3801,3815],{"className":3800},[136],[59,3802,3804],{"className":3803,"style":206},[140],[59,3805,3806,3809],{"style":209},[59,3807],{"className":3808,"style":213},[148],[59,3810,3812],{"className":3811},[153,217,218,156],[59,3813,96],{"className":3814,"style":222},[119,192,156],[59,3816,227],{"className":3817},[226],[59,3819,3821],{"className":3820},[136],[59,3822,3824],{"className":3823,"style":234},[140],[59,3825],{},",\n",[59,3828,3830,3882],{"className":3829},[62],[59,3831,3833],{"className":3832},[66],[68,3834,3835],{"xmlns":70},[72,3836,3837,3879],{},[75,3838,3839,3847,3849],{},[1822,3840,3841,3843,3845],{},[91,3842,2993],{},[91,3844,789],{},[91,3846,572],{},[461,3848,463],{},[465,3850,3851,3853],{},[91,3852,470],{"mathvariant":469},[75,3854,3855,3871,3873],{},[88,3856,3857,3859],{},[91,3858,93],{},[75,3860,3861,3863,3865,3867,3869],{},[91,3862,366],{"mathvariant":818},[91,3864,792],{"mathvariant":818},[91,3866,93],{"mathvariant":818},[91,3868,782],{"mathvariant":818},[91,3870,2903],{"mathvariant":818},[461,3872,477],{},[88,3874,3875,3877],{},[91,3876,93],{},[91,3878,96],{},[98,3880,3881],{"encoding":100},"W_i^K \\in \\mathbb{R}^{d_{\\mathrm{model}} \\times d_k}",[59,3883,3885,3952],{"className":3884,"ariaHidden":106},[105],[59,3886,3888,3892,3943,3946,3949],{"className":3887},[110],[59,3889],{"className":3890,"style":3891},[114],"height:1.1em;vertical-align:-0.2587em;",[59,3893,3895,3898],{"className":3894},[119],[59,3896,2993],{"className":3897,"style":3185},[119,192],[59,3899,3901],{"className":3900},[196],[59,3902,3904,3935],{"className":3903},[131,132],[59,3905,3907,3932],{"className":3906},[136],[59,3908,3910,3921],{"className":3909,"style":3198},[140],[59,3911,3912,3915],{"style":3477},[59,3913],{"className":3914,"style":540},[148],[59,3916,3918],{"className":3917},[153,154,155,156],[59,3919,789],{"className":3920},[119,192,156],[59,3922,3923,3926],{"style":536},[59,3924],{"className":3925,"style":540},[148],[59,3927,3929],{"className":3928},[153,154,155,156],[59,3930,572],{"className":3931,"style":603},[119,192,156],[59,3933,227],{"className":3934},[226],[59,3936,3938],{"className":3937},[136],[59,3939,3941],{"className":3940,"style":3507},[140],[59,3942],{},[59,3944],{"className":3945,"style":500},[499],[59,3947,463],{"className":3948},[504],[59,3950],{"className":3951,"style":500},[499],[59,3953,3955,3958],{"className":3954},[110],[59,3956],{"className":3957,"style":514},[114],[59,3959,3961,3964],{"className":3960},[119],[59,3962,470],{"className":3963},[119,521],[59,3965,3967],{"className":3966},[196],[59,3968,3970],{"className":3969},[131],[59,3971,3973],{"className":3972},[136],[59,3974,3976],{"className":3975,"style":514},[140],[59,3977,3978,3981],{"style":536},[59,3979],{"className":3980,"style":540},[148],[59,3982,3984],{"className":3983},[153,154,155,156],[59,3985,3987,4033,4036],{"className":3986},[119,156],[59,3988,3990,3993],{"className":3989},[119,156],[59,3991,93],{"className":3992},[119,192,156],[59,3994,3996],{"className":3995},[196],[59,3997,3999,4025],{"className":3998},[131,132],[59,4000,4002,4022],{"className":4001},[136],[59,4003,4005],{"className":4004,"style":206},[140],[59,4006,4007,4010],{"style":209},[59,4008],{"className":4009,"style":213},[148],[59,4011,4013],{"className":4012},[153,217,218,156],[59,4014,4016],{"className":4015},[119,156],[59,4017,4019],{"className":4018},[119,156],[59,4020,3771],{"className":4021},[119,955,156],[59,4023,227],{"className":4024},[226],[59,4026,4028],{"className":4027},[136],[59,4029,4031],{"className":4030,"style":234},[140],[59,4032],{},[59,4034,477],{"className":4035},[553,156],[59,4037,4039,4042],{"className":4038},[119,156],[59,4040,93],{"className":4041},[119,192,156],[59,4043,4045],{"className":4044},[196],[59,4046,4048,4068],{"className":4047},[131,132],[59,4049,4051,4065],{"className":4050},[136],[59,4052,4054],{"className":4053,"style":206},[140],[59,4055,4056,4059],{"style":209},[59,4057],{"className":4058,"style":213},[148],[59,4060,4062],{"className":4061},[153,217,218,156],[59,4063,96],{"className":4064,"style":222},[119,192,156],[59,4066,227],{"className":4067},[226],[59,4069,4071],{"className":4070},[136],[59,4072,4074],{"className":4073,"style":234},[140],[59,4075],{},[59,4077,4079,4131],{"className":4078},[62],[59,4080,4082],{"className":4081},[66],[68,4083,4084],{"xmlns":70},[72,4085,4086,4128],{},[75,4087,4088,4096,4098],{},[1822,4089,4090,4092,4094],{},[91,4091,2993],{},[91,4093,789],{},[91,4095,672],{},[461,4097,463],{},[465,4099,4100,4102],{},[91,4101,470],{"mathvariant":469},[75,4103,4104,4120,4122],{},[88,4105,4106,4108],{},[91,4107,93],{},[75,4109,4110,4112,4114,4116,4118],{},[91,4111,366],{"mathvariant":818},[91,4113,792],{"mathvariant":818},[91,4115,93],{"mathvariant":818},[91,4117,782],{"mathvariant":818},[91,4119,2903],{"mathvariant":818},[461,4121,477],{},[88,4123,4124,4126],{},[91,4125,93],{},[91,4127,426],{},[98,4129,4130],{"encoding":100},"W_i^V \\in \\mathbb{R}^{d_{\\mathrm{model}} \\times d_v}",[59,4132,4134,4200],{"className":4133,"ariaHidden":106},[105],[59,4135,4137,4140,4191,4194,4197],{"className":4136},[110],[59,4138],{"className":4139,"style":3891},[114],[59,4141,4143,4146],{"className":4142},[119],[59,4144,2993],{"className":4145,"style":3185},[119,192],[59,4147,4149],{"className":4148},[196],[59,4150,4152,4183],{"className":4151},[131,132],[59,4153,4155,4180],{"className":4154},[136],[59,4156,4158,4169],{"className":4157,"style":3198},[140],[59,4159,4160,4163],{"style":3477},[59,4161],{"className":4162,"style":540},[148],[59,4164,4166],{"className":4165},[153,154,155,156],[59,4167,789],{"className":4168},[119,192,156],[59,4170,4171,4174],{"style":536},[59,4172],{"className":4173,"style":540},[148],[59,4175,4177],{"className":4176},[153,154,155,156],[59,4178,672],{"className":4179,"style":702},[119,192,156],[59,4181,227],{"className":4182},[226],[59,4184,4186],{"className":4185},[136],[59,4187,4189],{"className":4188,"style":3507},[140],[59,4190],{},[59,4192],{"className":4193,"style":500},[499],[59,4195,463],{"className":4196},[504],[59,4198],{"className":4199,"style":500},[499],[59,4201,4203,4206],{"className":4202},[110],[59,4204],{"className":4205,"style":514},[114],[59,4207,4209,4212],{"className":4208},[119],[59,4210,470],{"className":4211},[119,521],[59,4213,4215],{"className":4214},[196],[59,4216,4218],{"className":4217},[131],[59,4219,4221],{"className":4220},[136],[59,4222,4224],{"className":4223,"style":514},[140],[59,4225,4226,4229],{"style":536},[59,4227],{"className":4228,"style":540},[148],[59,4230,4232],{"className":4231},[153,154,155,156],[59,4233,4235,4281,4284],{"className":4234},[119,156],[59,4236,4238,4241],{"className":4237},[119,156],[59,4239,93],{"className":4240},[119,192,156],[59,4242,4244],{"className":4243},[196],[59,4245,4247,4273],{"className":4246},[131,132],[59,4248,4250,4270],{"className":4249},[136],[59,4251,4253],{"className":4252,"style":206},[140],[59,4254,4255,4258],{"style":209},[59,4256],{"className":4257,"style":213},[148],[59,4259,4261],{"className":4260},[153,217,218,156],[59,4262,4264],{"className":4263},[119,156],[59,4265,4267],{"className":4266},[119,156],[59,4268,3771],{"className":4269},[119,955,156],[59,4271,227],{"className":4272},[226],[59,4274,4276],{"className":4275},[136],[59,4277,4279],{"className":4278,"style":234},[140],[59,4280],{},[59,4282,477],{"className":4283},[553,156],[59,4285,4287,4290],{"className":4286},[119,156],[59,4288,93],{"className":4289},[119,192,156],[59,4291,4293],{"className":4292},[196],[59,4294,4296,4318],{"className":4295},[131,132],[59,4297,4299,4315],{"className":4298},[136],[59,4300,4303],{"className":4301,"style":4302},[140],"height:0.1645em;",[59,4304,4306,4309],{"style":4305},"top:-2.357em;margin-left:0em;margin-right:0.0714em;",[59,4307],{"className":4308,"style":213},[148],[59,4310,4312],{"className":4311},[153,217,218,156],[59,4313,426],{"className":4314,"style":441},[119,192,156],[59,4316,227],{"className":4317},[226],[59,4319,4321],{"className":4320},[136],[59,4322,4325],{"className":4323,"style":4324},[140],"height:0.143em;",[59,4326],{},"\n和\n",[59,4329,4331,4383],{"className":4330},[62],[59,4332,4334],{"className":4333},[66],[68,4335,4336],{"xmlns":70},[72,4337,4338,4380],{},[75,4339,4340,4346,4348],{},[465,4341,4342,4344],{},[91,4343,2993],{},[91,4345,2996],{},[461,4347,463],{},[465,4349,4350,4352],{},[91,4351,470],{"mathvariant":469},[75,4353,4354,4356,4362,4364],{},[91,4355,2957],{},[88,4357,4358,4360],{},[91,4359,93],{},[91,4361,426],{},[461,4363,477],{},[88,4365,4366,4368],{},[91,4367,93],{},[75,4369,4370,4372,4374,4376,4378],{},[91,4371,366],{"mathvariant":818},[91,4373,792],{"mathvariant":818},[91,4375,93],{"mathvariant":818},[91,4377,782],{"mathvariant":818},[91,4379,2903],{"mathvariant":818},[98,4381,4382],{"encoding":100},"W^O \\in \\mathbb{R}^{h d_v \\times d_{\\mathrm{model}}}",[59,4384,4386,4431],{"className":4385,"ariaHidden":106},[105],[59,4387,4389,4393,4422,4425,4428],{"className":4388},[110],[59,4390],{"className":4391,"style":4392},[114],"height:0.8804em;vertical-align:-0.0391em;",[59,4394,4396,4399],{"className":4395},[119],[59,4397,2993],{"className":4398,"style":3185},[119,192],[59,4400,4402],{"className":4401},[196],[59,4403,4405],{"className":4404},[131],[59,4406,4408],{"className":4407},[136],[59,4409,4411],{"className":4410,"style":3198},[140],[59,4412,4413,4416],{"style":536},[59,4414],{"className":4415,"style":540},[148],[59,4417,4419],{"className":4418},[153,154,155,156],[59,4420,2996],{"className":4421,"style":3210},[119,192,156],[59,4423],{"className":4424,"style":500},[499],[59,4426,463],{"className":4427},[504],[59,4429],{"className":4430,"style":500},[499],[59,4432,4434,4437],{"className":4433},[110],[59,4435],{"className":4436,"style":514},[114],[59,4438,4440,4443],{"className":4439},[119],[59,4441,470],{"className":4442},[119,521],[59,4444,4446],{"className":4445},[196],[59,4447,4449],{"className":4448},[131],[59,4450,4452],{"className":4451},[136],[59,4453,4455],{"className":4454,"style":514},[140],[59,4456,4457,4460],{"style":536},[59,4458],{"className":4459,"style":540},[148],[59,4461,4463],{"className":4462},[153,154,155,156],[59,4464,4466,4469,4509,4512],{"className":4465},[119,156],[59,4467,2957],{"className":4468},[119,192,156],[59,4470,4472,4475],{"className":4471},[119,156],[59,4473,93],{"className":4474},[119,192,156],[59,4476,4478],{"className":4477},[196],[59,4479,4481,4501],{"className":4480},[131,132],[59,4482,4484,4498],{"className":4483},[136],[59,4485,4487],{"className":4486,"style":4302},[140],[59,4488,4489,4492],{"style":4305},[59,4490],{"className":4491,"style":213},[148],[59,4493,4495],{"className":4494},[153,217,218,156],[59,4496,426],{"className":4497,"style":441},[119,192,156],[59,4499,227],{"className":4500},[226],[59,4502,4504],{"className":4503},[136],[59,4505,4507],{"className":4506,"style":4324},[140],[59,4508],{},[59,4510,477],{"className":4511},[553,156],[59,4513,4515,4518],{"className":4514},[119,156],[59,4516,93],{"className":4517},[119,192,156],[59,4519,4521],{"className":4520},[196],[59,4522,4524,4550],{"className":4523},[131,132],[59,4525,4527,4547],{"className":4526},[136],[59,4528,4530],{"className":4529,"style":206},[140],[59,4531,4532,4535],{"style":209},[59,4533],{"className":4534,"style":213},[148],[59,4536,4538],{"className":4537},[153,217,218,156],[59,4539,4541],{"className":4540},[119,156],[59,4542,4544],{"className":4543},[119,156],[59,4545,3771],{"className":4546},[119,955,156],[59,4548,227],{"className":4549},[226],[59,4551,4553],{"className":4552},[136],[59,4554,4556],{"className":4555,"style":234},[140],[59,4557],{},[15,4559,4560,4561,4564,4565,4568,4569,4571,4572,4574],{},"为了能使多个头并行计算，我们定义两个转置函数 ",[2397,4562,4563],{},"transpose_output"," 和 ",[2397,4566,4567],{},"transpose_qkv","，其中 ",[2397,4570,4563],{}," 还原了 ",[2397,4573,4567],{}," 的结果。",[2390,4576,4578],{"className":2392,"code":4577,"language":2394,"meta":2395,"style":2395},"def transpose_qkv(X, num_heads):\n    \"\"\"为了多注意力头的并行计算而变换形状\"\"\"\n    # 输入X的形状:(batch_size，查询或者“键－值”对的个数，num_hiddens)\n    # 输出X的形状:(batch_size，查询或者“键－值”对的个数，num_heads，num_hiddens/num_heads)\n    X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)\n\n    # 输出X的形状:(batch_size，num_heads，查询或者“键－值”对的个数, num_hiddens/num_heads)\n    X = X.permute(0, 2, 1, 3)\n\n    # 最终输出的形状:(batch_size*num_heads,查询或者“键－值”对的个数, num_hiddens/num_heads)\n    return X.reshape(-1, X.shape[2], X.shape[3])\n\ndef transpose_output(X, num_heads):\n    \"\"\"逆转transpose_qkv函数的操作\"\"\"\n    X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])\n    X = X.permute(0, 2, 1, 3)\n    return X.reshape(X.shape[0], X.shape[1], -1)\n",[2397,4579,4580,4598,4603,4608,4613,4639,4643,4648,4674,4678,4683,4706,4710,4727,4732,4755,4779],{"__ignoreMap":2395},[59,4581,4582,4584,4587,4589,4591,4593,4596],{"class":2401,"line":2402},[59,4583,2406],{"class":2405},[59,4585,4586],{"class":2409}," transpose_qkv",[59,4588,798],{"class":2413},[59,4590,2417],{"class":2416},[59,4592,2420],{"class":2413},[59,4594,4595],{"class":2416},"num_heads",[59,4597,2426],{"class":2413},[59,4599,4600],{"class":2401,"line":2429},[59,4601,4602],{"class":2432},"    \"\"\"为了多注意力头的并行计算而变换形状\"\"\"\n",[59,4604,4605],{"class":2401,"line":2436},[59,4606,4607],{"class":2439},"    # 输入X的形状:(batch_size，查询或者“键－值”对的个数，num_hiddens)\n",[59,4609,4610],{"class":2401,"line":2443},[59,4611,4612],{"class":2439},"    # 输出X的形状:(batch_size，查询或者“键－值”对的个数，num_heads，num_hiddens/num_heads)\n",[59,4614,4615,4618,4620,4623,4625,4628,4630,4633,4635,4637],{"class":2401,"line":2462},[59,4616,4617],{"class":2413},"    X ",[59,4619,815],{"class":2405},[59,4621,4622],{"class":2413}," X.reshape(X.shape[",[59,4624,1754],{"class":2455},[59,4626,4627],{"class":2413},"], X.shape[",[59,4629,83],{"class":2455},[59,4631,4632],{"class":2413},"], num_heads, ",[59,4634,2553],{"class":2405},[59,4636,83],{"class":2455},[59,4638,2480],{"class":2413},[59,4640,4641],{"class":2401,"line":2483},[59,4642,2725],{"emptyLinePlaceholder":2724},[59,4644,4645],{"class":2401,"line":2491},[59,4646,4647],{"class":2439},"    # 输出X的形状:(batch_size，num_heads，查询或者“键－值”对的个数, num_hiddens/num_heads)\n",[59,4649,4650,4652,4654,4657,4659,4661,4663,4665,4667,4669,4672],{"class":2401,"line":2502},[59,4651,4617],{"class":2413},[59,4653,815],{"class":2405},[59,4655,4656],{"class":2413}," X.permute(",[59,4658,1754],{"class":2455},[59,4660,2420],{"class":2413},[59,4662,2821],{"class":2455},[59,4664,2420],{"class":2413},[59,4666,83],{"class":2455},[59,4668,2420],{"class":2413},[59,4670,4671],{"class":2455},"3",[59,4673,2480],{"class":2413},[59,4675,4676],{"class":2401,"line":2519},[59,4677,2725],{"emptyLinePlaceholder":2724},[59,4679,4680],{"class":2401,"line":2535},[59,4681,4682],{"class":2439},"    # 最终输出的形状:(batch_size*num_heads,查询或者“键－值”对的个数, num_hiddens/num_heads)\n",[59,4684,4685,4688,4691,4693,4695,4698,4700,4702,4704],{"class":2401,"line":2543},[59,4686,4687],{"class":2405},"    return",[59,4689,4690],{"class":2413}," X.reshape(",[59,4692,2553],{"class":2405},[59,4694,83],{"class":2455},[59,4696,4697],{"class":2413},", X.shape[",[59,4699,2821],{"class":2455},[59,4701,4627],{"class":2413},[59,4703,4671],{"class":2455},[59,4705,2532],{"class":2413},[59,4707,4708],{"class":2401,"line":2560},[59,4709,2725],{"emptyLinePlaceholder":2724},[59,4711,4712,4714,4717,4719,4721,4723,4725],{"class":2401,"line":2566},[59,4713,2406],{"class":2405},[59,4715,4716],{"class":2409}," transpose_output",[59,4718,798],{"class":2413},[59,4720,2417],{"class":2416},[59,4722,2420],{"class":2413},[59,4724,4595],{"class":2416},[59,4726,2426],{"class":2413},[59,4728,4729],{"class":2401,"line":2591},[59,4730,4731],{"class":2432},"    \"\"\"逆转transpose_qkv函数的操作\"\"\"\n",[59,4733,4734,4736,4738,4740,4742,4744,4747,4749,4751,4753],{"class":2401,"line":2604},[59,4735,4617],{"class":2413},[59,4737,815],{"class":2405},[59,4739,4690],{"class":2413},[59,4741,2553],{"class":2405},[59,4743,83],{"class":2455},[59,4745,4746],{"class":2413},", num_heads, X.shape[",[59,4748,83],{"class":2455},[59,4750,4627],{"class":2413},[59,4752,2821],{"class":2455},[59,4754,2532],{"class":2413},[59,4756,4757,4759,4761,4763,4765,4767,4769,4771,4773,4775,4777],{"class":2401,"line":2845},[59,4758,4617],{"class":2413},[59,4760,815],{"class":2405},[59,4762,4656],{"class":2413},[59,4764,1754],{"class":2455},[59,4766,2420],{"class":2413},[59,4768,2821],{"class":2455},[59,4770,2420],{"class":2413},[59,4772,83],{"class":2455},[59,4774,2420],{"class":2413},[59,4776,4671],{"class":2455},[59,4778,2480],{"class":2413},[59,4780,4782,4784,4786,4788,4790,4792,4795,4797,4799],{"class":2401,"line":4781},17,[59,4783,4687],{"class":2405},[59,4785,4622],{"class":2413},[59,4787,1754],{"class":2455},[59,4789,4627],{"class":2413},[59,4791,83],{"class":2455},[59,4793,4794],{"class":2413},"], ",[59,4796,2553],{"class":2405},[59,4798,83],{"class":2455},[59,4800,2480],{"class":2413},[2390,4802,4804],{"className":2392,"code":4803,"language":2394,"meta":2395,"style":2395},"#@save\nclass MultiHeadAttention(nn.Module):\n    \"\"\"多头注意力\"\"\"\n    def __init__(self, key_size, query_size, value_size, num_hiddens,\n                 num_heads, dropout, bias=False, **kwargs):\n        super(MultiHeadAttention, self).__init__(**kwargs)\n        self.num_heads = num_heads\n        self.attention = d2l.DotProductAttention(dropout)\n        self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)\n        self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)\n        self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)\n        self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)\n\n    def forward(self, queries, keys, values, valid_lens):\n        # queries，keys，values的形状:\n        # (batch_size，查询或者“键－值”对的个数，num_hiddens)\n        # valid_lens　的形状:\n        # (batch_size，)或(batch_size，查询的个数)\n        # 经过变换后，输出的queries，keys，values 的形状:\n        # (batch_size*num_heads，查询或者“键－值”对的个数，num_hiddens/num_heads)\n        queries = transpose_qkv(self.W_q(queries), self.num_heads)\n        keys = transpose_qkv(self.W_k(keys), self.num_heads)\n        values = transpose_qkv(self.W_v(values), self.num_heads)\n\n        if valid_lens is not None:\n            # 在轴0，将第一项（标量或者矢量）复制num_heads次，\n            # 然后如此复制第二项，然后诸如此类。\n            valid_lens = torch.repeat_interleave(\n                valid_lens, repeats=self.num_heads, dim=0)\n\n        # output的形状:(batch_size*num_heads，查询的个数， num_hiddens/num_heads)\n        output = self.attention(queries, keys, values, valid_lens)\n\n        # output_concat的形状:(batch_size，查询的个数，num_hiddens)\n        output_concat = transpose_output(output, self.num_heads)\n        return self.W_o(output_concat)\n",[2397,4805,4806,4811,4828,4833,4865,4892,4911,4923,4935,4954,4972,4990,5008,5012,5040,5045,5050,5055,5061,5067,5073,5094,5113,5132,5137,5153,5159,5165,5175,5199,5204,5210,5224,5229,5235,5250],{"__ignoreMap":2395},[59,4807,4808],{"class":2401,"line":2402},[59,4809,4810],{"class":2439},"#@save\n",[59,4812,4813,4815,4818,4820,4822,4824,4826],{"class":2401,"line":2429},[59,4814,2630],{"class":2405},[59,4816,4817],{"class":2633}," MultiHeadAttention",[59,4819,798],{"class":2413},[59,4821,2640],{"class":2639},[59,4823,864],{"class":2413},[59,4825,2645],{"class":2639},[59,4827,2426],{"class":2413},[59,4829,4830],{"class":2401,"line":2436},[59,4831,4832],{"class":2432},"    \"\"\"多头注意力\"\"\"\n",[59,4834,4835,4837,4839,4841,4843,4845,4848,4850,4853,4855,4858,4860,4863],{"class":2401,"line":2443},[59,4836,2657],{"class":2405},[59,4838,2660],{"class":2455},[59,4840,798],{"class":2413},[59,4842,2666],{"class":2665},[59,4844,2420],{"class":2413},[59,4846,4847],{"class":2416},"key_size",[59,4849,2420],{"class":2413},[59,4851,4852],{"class":2416},"query_size",[59,4854,2420],{"class":2413},[59,4856,4857],{"class":2416},"value_size",[59,4859,2420],{"class":2413},[59,4861,4862],{"class":2416},"num_hiddens",[59,4864,3826],{"class":2413},[59,4866,4867,4870,4872,4874,4876,4879,4881,4884,4886,4888,4890],{"class":2401,"line":2462},[59,4868,4869],{"class":2416},"                 num_heads",[59,4871,2420],{"class":2413},[59,4873,2671],{"class":2416},[59,4875,2420],{"class":2413},[59,4877,4878],{"class":2416},"bias",[59,4880,815],{"class":2405},[59,4882,4883],{"class":2455},"False",[59,4885,2420],{"class":2413},[59,4887,2676],{"class":2405},[59,4889,2679],{"class":2416},[59,4891,2426],{"class":2413},[59,4893,4894,4896,4899,4901,4903,4905,4907,4909],{"class":2401,"line":2483},[59,4895,2687],{"class":2686},[59,4897,4898],{"class":2413},"(MultiHeadAttention, ",[59,4900,2666],{"class":2693},[59,4902,2696],{"class":2413},[59,4904,2699],{"class":2455},[59,4906,798],{"class":2413},[59,4908,2676],{"class":2405},[59,4910,2706],{"class":2413},[59,4912,4913,4915,4918,4920],{"class":2401,"line":2491},[59,4914,2711],{"class":2693},[59,4916,4917],{"class":2413},".num_heads ",[59,4919,815],{"class":2405},[59,4921,4922],{"class":2413}," num_heads\n",[59,4924,4925,4927,4930,4932],{"class":2401,"line":2502},[59,4926,2711],{"class":2693},[59,4928,4929],{"class":2413},".attention ",[59,4931,815],{"class":2405},[59,4933,4934],{"class":2413}," d2l.DotProductAttention(dropout)\n",[59,4936,4937,4939,4942,4944,4947,4949,4951],{"class":2401,"line":2519},[59,4938,2711],{"class":2693},[59,4940,4941],{"class":2413},".W_q ",[59,4943,815],{"class":2405},[59,4945,4946],{"class":2413}," nn.Linear(query_size, num_hiddens, ",[59,4948,4878],{"class":2471},[59,4950,815],{"class":2405},[59,4952,4953],{"class":2413},"bias)\n",[59,4955,4956,4958,4961,4963,4966,4968,4970],{"class":2401,"line":2535},[59,4957,2711],{"class":2693},[59,4959,4960],{"class":2413},".W_k ",[59,4962,815],{"class":2405},[59,4964,4965],{"class":2413}," nn.Linear(key_size, num_hiddens, ",[59,4967,4878],{"class":2471},[59,4969,815],{"class":2405},[59,4971,4953],{"class":2413},[59,4973,4974,4976,4979,4981,4984,4986,4988],{"class":2401,"line":2543},[59,4975,2711],{"class":2693},[59,4977,4978],{"class":2413},".W_v ",[59,4980,815],{"class":2405},[59,4982,4983],{"class":2413}," nn.Linear(value_size, num_hiddens, ",[59,4985,4878],{"class":2471},[59,4987,815],{"class":2405},[59,4989,4953],{"class":2413},[59,4991,4992,4994,4997,4999,5002,5004,5006],{"class":2401,"line":2560},[59,4993,2711],{"class":2693},[59,4995,4996],{"class":2413},".W_o ",[59,4998,815],{"class":2405},[59,5000,5001],{"class":2413}," nn.Linear(num_hiddens, num_hiddens, ",[59,5003,4878],{"class":2471},[59,5005,815],{"class":2405},[59,5007,4953],{"class":2413},[59,5009,5010],{"class":2401,"line":2566},[59,5011,2725],{"emptyLinePlaceholder":2724},[59,5013,5014,5016,5018,5020,5022,5024,5026,5028,5030,5032,5034,5036,5038],{"class":2401,"line":2591},[59,5015,2657],{"class":2405},[59,5017,2752],{"class":2409},[59,5019,798],{"class":2413},[59,5021,2666],{"class":2665},[59,5023,2420],{"class":2413},[59,5025,2761],{"class":2416},[59,5027,2420],{"class":2413},[59,5029,2766],{"class":2416},[59,5031,2420],{"class":2413},[59,5033,2771],{"class":2416},[59,5035,2420],{"class":2413},[59,5037,2423],{"class":2416},[59,5039,2426],{"class":2413},[59,5041,5042],{"class":2401,"line":2604},[59,5043,5044],{"class":2439},"        # queries，keys，values的形状:\n",[59,5046,5047],{"class":2401,"line":2845},[59,5048,5049],{"class":2439},"        # (batch_size，查询或者“键－值”对的个数，num_hiddens)\n",[59,5051,5052],{"class":2401,"line":4781},[59,5053,5054],{"class":2439},"        # valid_lens　的形状:\n",[59,5056,5058],{"class":2401,"line":5057},18,[59,5059,5060],{"class":2439},"        # (batch_size，)或(batch_size，查询的个数)\n",[59,5062,5064],{"class":2401,"line":5063},19,[59,5065,5066],{"class":2439},"        # 经过变换后，输出的queries，keys，values 的形状:\n",[59,5068,5070],{"class":2401,"line":5069},20,[59,5071,5072],{"class":2439},"        # (batch_size*num_heads，查询或者“键－值”对的个数，num_hiddens/num_heads)\n",[59,5074,5076,5079,5081,5084,5086,5089,5091],{"class":2401,"line":5075},21,[59,5077,5078],{"class":2413},"        queries ",[59,5080,815],{"class":2405},[59,5082,5083],{"class":2413}," transpose_qkv(",[59,5085,2666],{"class":2693},[59,5087,5088],{"class":2413},".W_q(queries), ",[59,5090,2666],{"class":2693},[59,5092,5093],{"class":2413},".num_heads)\n",[59,5095,5097,5100,5102,5104,5106,5109,5111],{"class":2401,"line":5096},22,[59,5098,5099],{"class":2413},"        keys ",[59,5101,815],{"class":2405},[59,5103,5083],{"class":2413},[59,5105,2666],{"class":2693},[59,5107,5108],{"class":2413},".W_k(keys), ",[59,5110,2666],{"class":2693},[59,5112,5093],{"class":2413},[59,5114,5116,5119,5121,5123,5125,5128,5130],{"class":2401,"line":5115},23,[59,5117,5118],{"class":2413},"        values ",[59,5120,815],{"class":2405},[59,5122,5083],{"class":2413},[59,5124,2666],{"class":2693},[59,5126,5127],{"class":2413},".W_v(values), ",[59,5129,2666],{"class":2693},[59,5131,5093],{"class":2413},[59,5133,5135],{"class":2401,"line":5134},24,[59,5136,2725],{"emptyLinePlaceholder":2724},[59,5138,5140,5142,5144,5146,5149,5151],{"class":2401,"line":5139},25,[59,5141,2505],{"class":2405},[59,5143,2449],{"class":2413},[59,5145,2452],{"class":2405},[59,5147,5148],{"class":2405}," not",[59,5150,2456],{"class":2455},[59,5152,2459],{"class":2413},[59,5154,5156],{"class":2401,"line":5155},26,[59,5157,5158],{"class":2439},"            # 在轴0，将第一项（标量或者矢量）复制num_heads次，\n",[59,5160,5162],{"class":2401,"line":5161},27,[59,5163,5164],{"class":2439},"            # 然后如此复制第二项，然后诸如此类。\n",[59,5166,5168,5170,5172],{"class":2401,"line":5167},28,[59,5169,2522],{"class":2413},[59,5171,815],{"class":2405},[59,5173,5174],{"class":2413}," torch.repeat_interleave(\n",[59,5176,5178,5181,5184,5186,5188,5191,5193,5195,5197],{"class":2401,"line":5177},29,[59,5179,5180],{"class":2413},"                valid_lens, ",[59,5182,5183],{"class":2471},"repeats",[59,5185,815],{"class":2405},[59,5187,2666],{"class":2693},[59,5189,5190],{"class":2413},".num_heads, ",[59,5192,2472],{"class":2471},[59,5194,815],{"class":2405},[59,5196,1754],{"class":2455},[59,5198,2480],{"class":2413},[59,5200,5202],{"class":2401,"line":5201},30,[59,5203,2725],{"emptyLinePlaceholder":2724},[59,5205,5207],{"class":2401,"line":5206},31,[59,5208,5209],{"class":2439},"        # output的形状:(batch_size*num_heads，查询的个数， num_hiddens/num_heads)\n",[59,5211,5213,5216,5218,5221],{"class":2401,"line":5212},32,[59,5214,5215],{"class":2413},"        output ",[59,5217,815],{"class":2405},[59,5219,5220],{"class":2693}," self",[59,5222,5223],{"class":2413},".attention(queries, keys, values, valid_lens)\n",[59,5225,5227],{"class":2401,"line":5226},33,[59,5228,2725],{"emptyLinePlaceholder":2724},[59,5230,5232],{"class":2401,"line":5231},34,[59,5233,5234],{"class":2439},"        # output_concat的形状:(batch_size，查询的个数，num_hiddens)\n",[59,5236,5238,5241,5243,5246,5248],{"class":2401,"line":5237},35,[59,5239,5240],{"class":2413},"        output_concat ",[59,5242,815],{"class":2405},[59,5244,5245],{"class":2413}," transpose_output(output, ",[59,5247,2666],{"class":2693},[59,5249,5093],{"class":2413},[59,5251,5253,5255,5257],{"class":2401,"line":5252},36,[59,5254,2465],{"class":2405},[59,5256,5220],{"class":2693},[59,5258,5259],{"class":2413},".W_o(output_concat)\n",[11,5261,5262],{"id":5262},"位置编码",[15,5264,5265,5266,5357],{},"在处理 token 序列时 RNN 采用了按顺序逐个处理 token 的，但注意力机制为了并行计算放弃了顺序处理，这会导致其失去位置信息。为了使模型能够使用到位置信息，我们需要在 token 中注入一些和位置相关的信息，我们使用位置编码来注入绝对的或相对的位置信息。位置编码与嵌入的维度相同，均为 ",[59,5267,5269,5295],{"className":5268},[62],[59,5270,5272],{"className":5271},[66],[68,5273,5274],{"xmlns":70},[72,5275,5276,5292],{},[75,5277,5278,5284,5286,5288,5290],{},[88,5279,5280,5282],{},[91,5281,93],{},[91,5283,366],{},[91,5285,792],{},[91,5287,93],{},[91,5289,782],{},[91,5291,2903],{},[98,5293,5294],{"encoding":100},"d_model",[59,5296,5298],{"className":5297,"ariaHidden":106},[105],[59,5299,5301,5304,5344,5347,5350,5353],{"className":5300},[110],[59,5302],{"className":5303,"style":1229},[114],[59,5305,5307,5310],{"className":5306},[119],[59,5308,93],{"className":5309},[119,192],[59,5311,5313],{"className":5312},[196],[59,5314,5316,5336],{"className":5315},[131,132],[59,5317,5319,5333],{"className":5318},[136],[59,5320,5322],{"className":5321,"style":1322},[140],[59,5323,5324,5327],{"style":1251},[59,5325],{"className":5326,"style":540},[148],[59,5328,5330],{"className":5329},[153,154,155,156],[59,5331,366],{"className":5332},[119,192,156],[59,5334,227],{"className":5335},[226],[59,5337,5339],{"className":5338},[136],[59,5340,5342],{"className":5341,"style":1270},[140],[59,5343],{},[59,5345,792],{"className":5346},[119,192],[59,5348,93],{"className":5349},[119,192],[59,5351,782],{"className":5352},[119,192],[59,5354,2903],{"className":5355,"style":5356},[119,192],"margin-right:0.01968em;","，因此两者可以相加。位置编码可以通过学习得到也可以直接固定得到。",[15,5359,5360],{},"在这里，我们使用的是基于正弦函数和余弦函数的位置编码。",[15,5362,5363,5364,5529,5530,5558,5559,5587,5588,5760,5761,5822,5823,5852,5853,5890,5891,5947],{},"假设输入表示 ",[59,5365,5367,5410],{"className":5366},[62],[59,5368,5370],{"className":5369},[66],[68,5371,5372],{"xmlns":70},[72,5373,5374,5407],{},[75,5375,5376,5379,5381],{},[91,5377,2417],{"mathvariant":5378},"bold",[461,5380,463],{},[465,5382,5383,5385],{},[91,5384,470],{"mathvariant":469},[75,5386,5387,5389,5391],{},[91,5388,335],{},[461,5390,477],{},[88,5392,5393,5395],{},[91,5394,93],{},[75,5396,5397,5399,5401,5403,5405],{},[91,5398,366],{},[91,5400,792],{},[91,5402,93],{},[91,5404,782],{},[91,5406,2903],{},[98,5408,5409],{"encoding":100},"\\mathbf{X} \\in \\mathbb{R}^{n \\times d_{model}}",[59,5411,5413,5433],{"className":5412,"ariaHidden":106},[105],[59,5414,5416,5420,5424,5427,5430],{"className":5415},[110],[59,5417],{"className":5418,"style":5419},[114],"height:0.7252em;vertical-align:-0.0391em;",[59,5421,2417],{"className":5422},[119,5423],"mathbf",[59,5425],{"className":5426,"style":500},[499],[59,5428,463],{"className":5429},[504],[59,5431],{"className":5432,"style":500},[499],[59,5434,5436,5439],{"className":5435},[110],[59,5437],{"className":5438,"style":514},[114],[59,5440,5442,5445],{"className":5441},[119],[59,5443,470],{"className":5444},[119,521],[59,5446,5448],{"className":5447},[196],[59,5449,5451],{"className":5450},[131],[59,5452,5454],{"className":5453},[136],[59,5455,5457],{"className":5456,"style":514},[140],[59,5458,5459,5462],{"style":536},[59,5460],{"className":5461,"style":540},[148],[59,5463,5465],{"className":5464},[153,154,155,156],[59,5466,5468,5471,5474],{"className":5467},[119,156],[59,5469,335],{"className":5470},[119,192,156],[59,5472,477],{"className":5473},[553,156],[59,5475,5477,5480],{"className":5476},[119,156],[59,5478,93],{"className":5479},[119,192,156],[59,5481,5483],{"className":5482},[196],[59,5484,5486,5521],{"className":5485},[131,132],[59,5487,5489,5518],{"className":5488},[136],[59,5490,5492],{"className":5491,"style":206},[140],[59,5493,5494,5497],{"style":209},[59,5495],{"className":5496,"style":213},[148],[59,5498,5500],{"className":5499},[153,217,218,156],[59,5501,5503,5506,5509,5512,5515],{"className":5502},[119,156],[59,5504,366],{"className":5505},[119,192,156],[59,5507,792],{"className":5508},[119,192,156],[59,5510,93],{"className":5511},[119,192,156],[59,5513,782],{"className":5514},[119,192,156],[59,5516,2903],{"className":5517,"style":5356},[119,192,156],[59,5519,227],{"className":5520},[226],[59,5522,5524],{"className":5523},[136],[59,5525,5527],{"className":5526,"style":234},[140],[59,5528],{}," 包含一个序列中 ",[59,5531,5533,5546],{"className":5532},[62],[59,5534,5536],{"className":5535},[66],[68,5537,5538],{"xmlns":70},[72,5539,5540,5544],{},[75,5541,5542],{},[91,5543,335],{},[98,5545,335],{"encoding":100},[59,5547,5549],{"className":5548,"ariaHidden":106},[105],[59,5550,5552,5555],{"className":5551},[110],[59,5553],{"className":5554,"style":347},[114],[59,5556,335],{"className":5557},[119,192]," 个词元的 ",[59,5560,5562,5575],{"className":5561},[62],[59,5563,5565],{"className":5564},[66],[68,5566,5567],{"xmlns":70},[72,5568,5569,5573],{},[75,5570,5571],{},[91,5572,93],{},[98,5574,93],{"encoding":100},[59,5576,5578],{"className":5577,"ariaHidden":106},[105],[59,5579,5581,5584],{"className":5580},[110],[59,5582],{"className":5583,"style":407},[114],[59,5585,93],{"className":5586},[119,192]," 维嵌入表示。位置编码使用相同形状的位置嵌入矩阵 ",[59,5589,5591,5639],{"className":5590},[62],[59,5592,5594],{"className":5593},[66],[68,5595,5596],{"xmlns":70},[72,5597,5598,5636],{},[75,5599,5600,5608,5610],{},[75,5601,5602,5605],{},[91,5603,5604],{"mathvariant":5378},"P",[91,5606,5607],{"mathvariant":5378},"E",[461,5609,463],{},[465,5611,5612,5614],{},[91,5613,470],{"mathvariant":469},[75,5615,5616,5618,5620],{},[91,5617,335],{},[461,5619,477],{},[88,5621,5622,5624],{},[91,5623,93],{},[75,5625,5626,5628,5630,5632,5634],{},[91,5627,366],{},[91,5629,792],{},[91,5631,93],{},[91,5633,782],{},[91,5635,2903],{},[98,5637,5638],{"encoding":100},"\\mathbf{PE} \\in \\mathbb{R}^{n \\times d_{model}}",[59,5640,5642,5664],{"className":5641,"ariaHidden":106},[105],[59,5643,5645,5648,5655,5658,5661],{"className":5644},[110],[59,5646],{"className":5647,"style":5419},[114],[59,5649,5651],{"className":5650},[119],[59,5652,5654],{"className":5653},[119,5423],"PE",[59,5656],{"className":5657,"style":500},[499],[59,5659,463],{"className":5660},[504],[59,5662],{"className":5663,"style":500},[499],[59,5665,5667,5670],{"className":5666},[110],[59,5668],{"className":5669,"style":514},[114],[59,5671,5673,5676],{"className":5672},[119],[59,5674,470],{"className":5675},[119,521],[59,5677,5679],{"className":5678},[196],[59,5680,5682],{"className":5681},[131],[59,5683,5685],{"className":5684},[136],[59,5686,5688],{"className":5687,"style":514},[140],[59,5689,5690,5693],{"style":536},[59,5691],{"className":5692,"style":540},[148],[59,5694,5696],{"className":5695},[153,154,155,156],[59,5697,5699,5702,5705],{"className":5698},[119,156],[59,5700,335],{"className":5701},[119,192,156],[59,5703,477],{"className":5704},[553,156],[59,5706,5708,5711],{"className":5707},[119,156],[59,5709,93],{"className":5710},[119,192,156],[59,5712,5714],{"className":5713},[196],[59,5715,5717,5752],{"className":5716},[131,132],[59,5718,5720,5749],{"className":5719},[136],[59,5721,5723],{"className":5722,"style":206},[140],[59,5724,5725,5728],{"style":209},[59,5726],{"className":5727,"style":213},[148],[59,5729,5731],{"className":5730},[153,217,218,156],[59,5732,5734,5737,5740,5743,5746],{"className":5733},[119,156],[59,5735,366],{"className":5736},[119,192,156],[59,5738,792],{"className":5739},[119,192,156],[59,5741,93],{"className":5742},[119,192,156],[59,5744,782],{"className":5745},[119,192,156],[59,5747,2903],{"className":5748,"style":5356},[119,192,156],[59,5750,227],{"className":5751},[226],[59,5753,5755],{"className":5754},[136],[59,5756,5758],{"className":5757,"style":234},[140],[59,5759],{}," 输出 ",[59,5762,5764,5787],{"className":5763},[62],[59,5765,5767],{"className":5766},[66],[68,5768,5769],{"xmlns":70},[72,5770,5771,5784],{},[75,5772,5773,5775,5778],{},[91,5774,2417],{"mathvariant":5378},[461,5776,5777],{},"+",[75,5779,5780,5782],{},[91,5781,5604],{"mathvariant":5378},[91,5783,5607],{"mathvariant":5378},[98,5785,5786],{"encoding":100},"\\mathbf{X} + \\mathbf{PE}",[59,5788,5790,5809],{"className":5789,"ariaHidden":106},[105],[59,5791,5793,5797,5800,5803,5806],{"className":5792},[110],[59,5794],{"className":5795,"style":5796},[114],"height:0.7694em;vertical-align:-0.0833em;",[59,5798,2417],{"className":5799},[119,5423],[59,5801],{"className":5802,"style":1872},[499],[59,5804,5777],{"className":5805},[553],[59,5807],{"className":5808,"style":1872},[499],[59,5810,5812,5816],{"className":5811},[110],[59,5813],{"className":5814,"style":5815},[114],"height:0.6861em;",[59,5817,5819],{"className":5818},[119],[59,5820,5654],{"className":5821},[119,5423],"，矩阵第 ",[59,5824,5826,5839],{"className":5825},[62],[59,5827,5829],{"className":5828},[66],[68,5830,5831],{"xmlns":70},[72,5832,5833,5837],{},[75,5834,5835],{},[91,5836,789],{},[98,5838,789],{"encoding":100},[59,5840,5842],{"className":5841,"ariaHidden":106},[105],[59,5843,5845,5849],{"className":5844},[110],[59,5846],{"className":5847,"style":5848},[114],"height:0.6595em;",[59,5850,789],{"className":5851},[119,192]," 行、第 ",[59,5854,5856,5873],{"className":5855},[62],[59,5857,5859],{"className":5858},[66],[68,5860,5861],{"xmlns":70},[72,5862,5863,5870],{},[75,5864,5865,5867],{},[81,5866,2821],{},[91,5868,5869],{},"j",[98,5871,5872],{"encoding":100},"2j",[59,5874,5876],{"className":5875,"ariaHidden":106},[105],[59,5877,5879,5883,5886],{"className":5878},[110],[59,5880],{"className":5881,"style":5882},[114],"height:0.854em;vertical-align:-0.1944em;",[59,5884,2821],{"className":5885},[119],[59,5887,5869],{"className":5888,"style":5889},[119,192],"margin-right:0.05724em;"," 列和 ",[59,5892,5894,5914],{"className":5893},[62],[59,5895,5897],{"className":5896},[66],[68,5898,5899],{"xmlns":70},[72,5900,5901,5911],{},[75,5902,5903,5905,5907,5909],{},[81,5904,2821],{},[91,5906,5869],{},[461,5908,5777],{},[81,5910,83],{},[98,5912,5913],{"encoding":100},"2j+1",[59,5915,5917,5938],{"className":5916,"ariaHidden":106},[105],[59,5918,5920,5923,5926,5929,5932,5935],{"className":5919},[110],[59,5921],{"className":5922,"style":5882},[114],[59,5924,2821],{"className":5925},[119],[59,5927,5869],{"className":5928,"style":5889},[119,192],[59,5930],{"className":5931,"style":1872},[499],[59,5933,5777],{"className":5934},[553],[59,5936],{"className":5937,"style":1872},[499],[59,5939,5941,5944],{"className":5940},[110],[59,5942],{"className":5943,"style":1766},[114],[59,5945,83],{"className":5946},[119]," 列上的元素为：",[15,5949,5950],{},[59,5951,5953,6031],{"className":5952},[62],[59,5954,5956],{"className":5955},[66],[68,5957,5958],{"xmlns":70},[72,5959,5960,6028],{},[75,5961,5962,5964,5982,5984,5987,5989],{},[91,5963,5604],{},[88,5965,5966,5968],{},[91,5967,5607],{},[75,5969,5970,5972,5974,5976,5978,5980],{},[461,5971,798],{"stretchy":797},[91,5973,789],{},[461,5975,803],{"separator":106},[81,5977,2821],{},[91,5979,5869],{},[461,5981,812],{"stretchy":797},[461,5983,815],{},[91,5985,5986],{},"sin",[461,5988,822],{},[75,5990,5991,5993,6026],{},[461,5992,798],{"fence":106},[78,5994,5995,5997],{},[91,5996,789],{},[465,5998,5999,6002],{},[81,6000,6001],{},"10000",[75,6003,6004,6006,6008,6010],{},[81,6005,2821],{},[91,6007,5869],{},[91,6009,2827],{"mathvariant":818},[88,6011,6012,6014],{},[91,6013,93],{},[75,6015,6016,6018,6020,6022,6024],{},[91,6017,366],{},[91,6019,792],{},[91,6021,93],{},[91,6023,782],{},[91,6025,2903],{},[461,6027,812],{"fence":106},[98,6029,6030],{"encoding":100}," PE_{(i,2j)} = \\sin\\left(\\frac{i}{10000^{2j/d_{model}}}\\right)",[59,6032,6034,6114],{"className":6033,"ariaHidden":106},[105],[59,6035,6037,6041,6044,6105,6108,6111],{"className":6036},[110],[59,6038],{"className":6039,"style":6040},[114],"height:1.0385em;vertical-align:-0.3552em;",[59,6042,5604],{"className":6043,"style":3185},[119,192],[59,6045,6047,6051],{"className":6046},[119],[59,6048,5607],{"className":6049,"style":6050},[119,192],"margin-right:0.05764em;",[59,6052,6054],{"className":6053},[196],[59,6055,6057,6096],{"className":6056},[131,132],[59,6058,6060,6093],{"className":6059},[136],[59,6061,6063],{"className":6062,"style":206},[140],[59,6064,6066,6069],{"style":6065},"top:-2.5198em;margin-left:-0.0576em;margin-right:0.05em;",[59,6067],{"className":6068,"style":540},[148],[59,6070,6072],{"className":6071},[153,154,155,156],[59,6073,6075,6078,6081,6084,6087,6090],{"className":6074},[119,156],[59,6076,798],{"className":6077},[123,156],[59,6079,789],{"className":6080},[119,192,156],[59,6082,803],{"className":6083},[912,156],[59,6085,2821],{"className":6086},[119,156],[59,6088,5869],{"className":6089,"style":5889},[119,192,156],[59,6091,812],{"className":6092},[313,156],[59,6094,227],{"className":6095},[226],[59,6097,6099],{"className":6098},[136],[59,6100,6103],{"className":6101,"style":6102},[140],"height:0.3552em;",[59,6104],{},[59,6106],{"className":6107,"style":500},[499],[59,6109,815],{"className":6110},[504],[59,6112],{"className":6113,"style":500},[499],[59,6115,6117,6120,6123,6126],{"className":6116},[110],[59,6118],{"className":6119,"style":947},[114],[59,6121,5986],{"className":6122},[951],[59,6124],{"className":6125,"style":916},[499],[59,6127,6129,6135,6303],{"className":6128},[962],[59,6130,6132],{"className":6131,"style":967},[123,966],[59,6133,798],{"className":6134},[971,972],[59,6136,6138,6141,6300],{"className":6137},[119],[59,6139],{"className":6140},[123,124],[59,6142,6144],{"className":6143},[78],[59,6145,6147,6291],{"className":6146},[131,132],[59,6148,6150,6288],{"className":6149},[136],[59,6151,6154,6266,6274],{"className":6152,"style":6153},[140],"height:0.8557em;",[59,6155,6157,6160],{"style":6156},"top:-2.5648em;",[59,6158],{"className":6159,"style":149},[148],[59,6161,6163],{"className":6162},[153,154,155,156],[59,6164,6166,6170],{"className":6165},[119,156],[59,6167,6169],{"className":6168},[119,156],"1000",[59,6171,6173,6176],{"className":6172},[119,156],[59,6174,1754],{"className":6175},[119,156],[59,6177,6179],{"className":6178},[196],[59,6180,6182],{"className":6181},[131],[59,6183,6185],{"className":6184},[136],[59,6186,6189],{"className":6187,"style":6188},[140],"height:0.8932em;",[59,6190,6192,6196],{"style":6191},"top:-2.8932em;margin-right:0.0714em;",[59,6193],{"className":6194,"style":6195},[148],"height:2.5357em;",[59,6197,6199],{"className":6198},[153,217,218,156],[59,6200,6202,6205,6208,6211],{"className":6201},[119,156],[59,6203,2821],{"className":6204},[119,156],[59,6206,5869],{"className":6207,"style":5889},[119,192,156],[59,6209,2827],{"className":6210},[119,156],[59,6212,6214,6217],{"className":6213},[119,156],[59,6215,93],{"className":6216},[119,192,156],[59,6218,6220],{"className":6219},[196],[59,6221,6223,6257],{"className":6222},[131,132],[59,6224,6226,6254],{"className":6225},[136],[59,6227,6229],{"className":6228,"style":206},[140],[59,6230,6232,6236],{"style":6231},"top:-2.3448em;margin-left:0em;margin-right:0.1em;",[59,6233],{"className":6234,"style":6235},[148],"height:2.6944em;",[59,6237,6239,6242,6245,6248,6251],{"className":6238},[119,156],[59,6240,366],{"className":6241},[119,192,156],[59,6243,792],{"className":6244},[119,192,156],[59,6246,93],{"className":6247},[119,192,156],[59,6249,782],{"className":6250},[119,192,156],[59,6252,2903],{"className":6253,"style":5356},[119,192,156],[59,6255,227],{"className":6256},[226],[59,6258,6260],{"className":6259},[136],[59,6261,6264],{"className":6262,"style":6263},[140],"height:0.3496em;",[59,6265],{},[59,6267,6268,6271],{"style":274},[59,6269],{"className":6270,"style":149},[148],[59,6272],{"className":6273,"style":282},[281],[59,6275,6276,6279],{"style":285},[59,6277],{"className":6278,"style":149},[148],[59,6280,6282],{"className":6281},[153,154,155,156],[59,6283,6285],{"className":6284},[119,156],[59,6286,789],{"className":6287},[119,192,156],[59,6289,227],{"className":6290},[226],[59,6292,6294],{"className":6293},[136],[59,6295,6298],{"className":6296,"style":6297},[140],"height:0.4352em;",[59,6299],{},[59,6301],{"className":6302},[313,124],[59,6304,6306],{"className":6305,"style":967},[313,966],[59,6307,812],{"className":6308},[971,972],[15,6310,6311],{},[59,6312,6314,6395],{"className":6313},[62],[59,6315,6317],{"className":6316},[66],[68,6318,6319],{"xmlns":70},[72,6320,6321,6392],{},[75,6322,6323,6325,6347,6349,6352,6354],{},[91,6324,5604],{},[88,6326,6327,6329],{},[91,6328,5607],{},[75,6330,6331,6333,6335,6337,6339,6341,6343,6345],{},[461,6332,798],{"stretchy":797},[91,6334,789],{},[461,6336,803],{"separator":106},[81,6338,2821],{},[91,6340,5869],{},[461,6342,5777],{},[81,6344,83],{},[461,6346,812],{"stretchy":797},[461,6348,815],{},[91,6350,6351],{},"cos",[461,6353,822],{},[75,6355,6356,6358,6390],{},[461,6357,798],{"fence":106},[78,6359,6360,6362],{},[91,6361,789],{},[465,6363,6364,6366],{},[81,6365,6001],{},[75,6367,6368,6370,6372,6374],{},[81,6369,2821],{},[91,6371,5869],{},[91,6373,2827],{"mathvariant":818},[88,6375,6376,6378],{},[91,6377,93],{},[75,6379,6380,6382,6384,6386,6388],{},[91,6381,366],{},[91,6383,792],{},[91,6385,93],{},[91,6387,782],{},[91,6389,2903],{},[461,6391,812],{"fence":106},[98,6393,6394],{"encoding":100}," PE_{(i,2j+1)} = \\cos\\left(\\frac{i}{10000^{2j/d_{model}}}\\right)",[59,6396,6398,6480],{"className":6397,"ariaHidden":106},[105],[59,6399,6401,6404,6407,6471,6474,6477],{"className":6400},[110],[59,6402],{"className":6403,"style":6040},[114],[59,6405,5604],{"className":6406,"style":3185},[119,192],[59,6408,6410,6413],{"className":6409},[119],[59,6411,5607],{"className":6412,"style":6050},[119,192],[59,6414,6416],{"className":6415},[196],[59,6417,6419,6463],{"className":6418},[131,132],[59,6420,6422,6460],{"className":6421},[136],[59,6423,6425],{"className":6424,"style":206},[140],[59,6426,6427,6430],{"style":6065},[59,6428],{"className":6429,"style":540},[148],[59,6431,6433],{"className":6432},[153,154,155,156],[59,6434,6436,6439,6442,6445,6448,6451,6454,6457],{"className":6435},[119,156],[59,6437,798],{"className":6438},[123,156],[59,6440,789],{"className":6441},[119,192,156],[59,6443,803],{"className":6444},[912,156],[59,6446,2821],{"className":6447},[119,156],[59,6449,5869],{"className":6450,"style":5889},[119,192,156],[59,6452,5777],{"className":6453},[553,156],[59,6455,83],{"className":6456},[119,156],[59,6458,812],{"className":6459},[313,156],[59,6461,227],{"className":6462},[226],[59,6464,6466],{"className":6465},[136],[59,6467,6469],{"className":6468,"style":6102},[140],[59,6470],{},[59,6472],{"className":6473,"style":500},[499],[59,6475,815],{"className":6476},[504],[59,6478],{"className":6479,"style":500},[499],[59,6481,6483,6486,6489,6492],{"className":6482},[110],[59,6484],{"className":6485,"style":947},[114],[59,6487,6351],{"className":6488},[951],[59,6490],{"className":6491,"style":916},[499],[59,6493,6495,6501,6659],{"className":6494},[962],[59,6496,6498],{"className":6497,"style":967},[123,966],[59,6499,798],{"className":6500},[971,972],[59,6502,6504,6507,6656],{"className":6503},[119],[59,6505],{"className":6506},[123,124],[59,6508,6510],{"className":6509},[78],[59,6511,6513,6648],{"className":6512},[131,132],[59,6514,6516,6645],{"className":6515},[136],[59,6517,6519,6623,6631],{"className":6518,"style":6153},[140],[59,6520,6521,6524],{"style":6156},[59,6522],{"className":6523,"style":149},[148],[59,6525,6527],{"className":6526},[153,154,155,156],[59,6528,6530,6533],{"className":6529},[119,156],[59,6531,6169],{"className":6532},[119,156],[59,6534,6536,6539],{"className":6535},[119,156],[59,6537,1754],{"className":6538},[119,156],[59,6540,6542],{"className":6541},[196],[59,6543,6545],{"className":6544},[131],[59,6546,6548],{"className":6547},[136],[59,6549,6551],{"className":6550,"style":6188},[140],[59,6552,6553,6556],{"style":6191},[59,6554],{"className":6555,"style":6195},[148],[59,6557,6559],{"className":6558},[153,217,218,156],[59,6560,6562,6565,6568,6571],{"className":6561},[119,156],[59,6563,2821],{"className":6564},[119,156],[59,6566,5869],{"className":6567,"style":5889},[119,192,156],[59,6569,2827],{"className":6570},[119,156],[59,6572,6574,6577],{"className":6573},[119,156],[59,6575,93],{"className":6576},[119,192,156],[59,6578,6580],{"className":6579},[196],[59,6581,6583,6615],{"className":6582},[131,132],[59,6584,6586,6612],{"className":6585},[136],[59,6587,6589],{"className":6588,"style":206},[140],[59,6590,6591,6594],{"style":6231},[59,6592],{"className":6593,"style":6235},[148],[59,6595,6597,6600,6603,6606,6609],{"className":6596},[119,156],[59,6598,366],{"className":6599},[119,192,156],[59,6601,792],{"className":6602},[119,192,156],[59,6604,93],{"className":6605},[119,192,156],[59,6607,782],{"className":6608},[119,192,156],[59,6610,2903],{"className":6611,"style":5356},[119,192,156],[59,6613,227],{"className":6614},[226],[59,6616,6618],{"className":6617},[136],[59,6619,6621],{"className":6620,"style":6263},[140],[59,6622],{},[59,6624,6625,6628],{"style":274},[59,6626],{"className":6627,"style":149},[148],[59,6629],{"className":6630,"style":282},[281],[59,6632,6633,6636],{"style":285},[59,6634],{"className":6635,"style":149},[148],[59,6637,6639],{"className":6638},[153,154,155,156],[59,6640,6642],{"className":6641},[119,156],[59,6643,789],{"className":6644},[119,192,156],[59,6646,227],{"className":6647},[226],[59,6649,6651],{"className":6650},[136],[59,6652,6654],{"className":6653,"style":6297},[140],[59,6655],{},[59,6657],{"className":6658},[313,124],[59,6660,6662],{"className":6661,"style":967},[313,966],[59,6663,812],{"className":6664},[971,972],[15,6666,6667,6668,6696,6697,6725,6726,6761,6762,6818,6819,6847,6848,6941,6942,7021],{},"其中，",[59,6669,6671,6684],{"className":6670},[62],[59,6672,6674],{"className":6673},[66],[68,6675,6676],{"xmlns":70},[72,6677,6678,6682],{},[75,6679,6680],{},[91,6681,789],{},[98,6683,789],{"encoding":100},[59,6685,6687],{"className":6686,"ariaHidden":106},[105],[59,6688,6690,6693],{"className":6689},[110],[59,6691],{"className":6692,"style":5848},[114],[59,6694,789],{"className":6695},[119,192]," 表示位置，",[59,6698,6700,6713],{"className":6699},[62],[59,6701,6703],{"className":6702},[66],[68,6704,6705],{"xmlns":70},[72,6706,6707,6711],{},[75,6708,6709],{},[91,6710,5869],{},[98,6712,5869],{"encoding":100},[59,6714,6716],{"className":6715,"ariaHidden":106},[105],[59,6717,6719,6722],{"className":6718},[110],[59,6720],{"className":6721,"style":5882},[114],[59,6723,5869],{"className":6724,"style":5889},[119,192]," 表示维度。也就是说，位置编码的每一个维度都对应一个正弦曲线。其波长从 ",[59,6727,6729,6746],{"className":6728},[62],[59,6730,6732],{"className":6731},[66],[68,6733,6734],{"xmlns":70},[72,6735,6736,6743],{},[75,6737,6738,6740],{},[81,6739,2821],{},[91,6741,6742],{},"π",[98,6744,6745],{"encoding":100},"2\\pi",[59,6747,6749],{"className":6748,"ariaHidden":106},[105],[59,6750,6752,6755,6758],{"className":6751},[110],[59,6753],{"className":6754,"style":1766},[114],[59,6756,2821],{"className":6757},[119],[59,6759,6742],{"className":6760,"style":441},[119,192]," 到 ",[59,6763,6765,6785],{"className":6764},[62],[59,6766,6768],{"className":6767},[66],[68,6769,6770],{"xmlns":70},[72,6771,6772,6782],{},[75,6773,6774,6776,6778,6780],{},[81,6775,6001],{},[461,6777,1816],{},[81,6779,2821],{},[91,6781,6742],{},[98,6783,6784],{"encoding":100},"10000 \\cdot 2\\pi",[59,6786,6788,6806],{"className":6787,"ariaHidden":106},[105],[59,6789,6791,6794,6797,6800,6803],{"className":6790},[110],[59,6792],{"className":6793,"style":1766},[114],[59,6795,6001],{"className":6796},[119],[59,6798],{"className":6799,"style":1872},[499],[59,6801,1816],{"className":6802},[553],[59,6804],{"className":6805,"style":1872},[499],[59,6807,6809,6812,6815],{"className":6808},[110],[59,6810],{"className":6811,"style":1766},[114],[59,6813,2821],{"className":6814},[119],[59,6816,6742],{"className":6817,"style":441},[119,192]," 按几何级数增长。我们选择这个函数，是因为我们假设它能够使模型更容易学习关注相对位置，因为对于任意固定的偏移量 ",[59,6820,6822,6835],{"className":6821},[62],[59,6823,6825],{"className":6824},[66],[68,6826,6827],{"xmlns":70},[72,6828,6829,6833],{},[75,6830,6831],{},[91,6832,96],{},[98,6834,96],{"encoding":100},[59,6836,6838],{"className":6837,"ariaHidden":106},[105],[59,6839,6841,6844],{"className":6840},[110],[59,6842],{"className":6843,"style":407},[114],[59,6845,96],{"className":6846,"style":222},[119,192],"，",[59,6849,6851,6877],{"className":6850},[62],[59,6852,6854],{"className":6853},[66],[68,6855,6856],{"xmlns":70},[72,6857,6858,6874],{},[75,6859,6860,6862],{},[91,6861,5604],{},[88,6863,6864,6866],{},[91,6865,5607],{},[75,6867,6868,6870,6872],{},[91,6869,789],{},[461,6871,5777],{},[91,6873,96],{},[98,6875,6876],{"encoding":100},"PE_{i+k}",[59,6878,6880],{"className":6879,"ariaHidden":106},[105],[59,6881,6883,6887,6890],{"className":6882},[110],[59,6884],{"className":6885,"style":6886},[114],"height:0.8917em;vertical-align:-0.2083em;",[59,6888,5604],{"className":6889,"style":3185},[119,192],[59,6891,6893,6896],{"className":6892},[119],[59,6894,5607],{"className":6895,"style":6050},[119,192],[59,6897,6899],{"className":6898},[196],[59,6900,6902,6932],{"className":6901},[131,132],[59,6903,6905,6929],{"className":6904},[136],[59,6906,6908],{"className":6907,"style":1248},[140],[59,6909,6911,6914],{"style":6910},"top:-2.55em;margin-left:-0.0576em;margin-right:0.05em;",[59,6912],{"className":6913,"style":540},[148],[59,6915,6917],{"className":6916},[153,154,155,156],[59,6918,6920,6923,6926],{"className":6919},[119,156],[59,6921,789],{"className":6922},[119,192,156],[59,6924,5777],{"className":6925},[553,156],[59,6927,96],{"className":6928,"style":222},[119,192,156],[59,6930,227],{"className":6931},[226],[59,6933,6935],{"className":6934},[136],[59,6936,6939],{"className":6937,"style":6938},[140],"height:0.2083em;",[59,6940],{}," 都可以表示为 ",[59,6943,6945,6965],{"className":6944},[62],[59,6946,6948],{"className":6947},[66],[68,6949,6950],{"xmlns":70},[72,6951,6952,6962],{},[75,6953,6954,6956],{},[91,6955,5604],{},[88,6957,6958,6960],{},[91,6959,5607],{},[91,6961,789],{},[98,6963,6964],{"encoding":100},"PE_{i}",[59,6966,6968],{"className":6967,"ariaHidden":106},[105],[59,6969,6971,6975,6978],{"className":6970},[110],[59,6972],{"className":6973,"style":6974},[114],"height:0.8333em;vertical-align:-0.15em;",[59,6976,5604],{"className":6977,"style":3185},[119,192],[59,6979,6981,6984],{"className":6980},[119],[59,6982,5607],{"className":6983,"style":6050},[119,192],[59,6985,6987],{"className":6986},[196],[59,6988,6990,7013],{"className":6989},[131,132],[59,6991,6993,7010],{"className":6992},[136],[59,6994,6996],{"className":6995,"style":2032},[140],[59,6997,6998,7001],{"style":6910},[59,6999],{"className":7000,"style":540},[148],[59,7002,7004],{"className":7003},[153,154,155,156],[59,7005,7007],{"className":7006},[119,156],[59,7008,789],{"className":7009},[119,192,156],[59,7011,227],{"className":7012},[226],[59,7014,7016],{"className":7015},[136],[59,7017,7019],{"className":7018,"style":1270},[140],[59,7020],{}," 的线性函数。",[2390,7023,7025],{"className":2392,"code":7024,"language":2394,"meta":2395,"style":2395},"class PositionalEncoding(nn.Block):\n    \"\"\"位置编码\"\"\"\n    def __init__(self, num_hiddens, dropout, max_len=1000):\n        super(PositionalEncoding, self).__init__()\n        self.dropout = nn.Dropout(dropout)\n        # 创建一个足够长的P\n        self.P = np.zeros((1, max_len, num_hiddens))\n        X = np.arange(max_len).reshape(-1, 1) / np.power(\n            10000, np.arange(0, num_hiddens, 2) / num_hiddens)\n        self.P[:, :, 0::2] = np.sin(X)\n        self.P[:, :, 1::2] = np.cos(X)\n\n    def forward(self, X):\n        X = X + self.P[:, :X.shape[1], :].as_in_ctx(X.ctx)\n        return self.dropout(X)\n",[2397,7026,7027,7045,7050,7079,7095,7105,7110,7127,7152,7174,7206,7233,7237,7253,7286],{"__ignoreMap":2395},[59,7028,7029,7031,7034,7036,7038,7040,7043],{"class":2401,"line":2402},[59,7030,2630],{"class":2405},[59,7032,7033],{"class":2633}," PositionalEncoding",[59,7035,798],{"class":2413},[59,7037,2640],{"class":2639},[59,7039,864],{"class":2413},[59,7041,7042],{"class":2639},"Block",[59,7044,2426],{"class":2413},[59,7046,7047],{"class":2401,"line":2429},[59,7048,7049],{"class":2432},"    \"\"\"位置编码\"\"\"\n",[59,7051,7052,7054,7056,7058,7060,7062,7064,7066,7068,7070,7073,7075,7077],{"class":2401,"line":2436},[59,7053,2657],{"class":2405},[59,7055,2660],{"class":2455},[59,7057,798],{"class":2413},[59,7059,2666],{"class":2665},[59,7061,2420],{"class":2413},[59,7063,4862],{"class":2416},[59,7065,2420],{"class":2413},[59,7067,2671],{"class":2416},[59,7069,2420],{"class":2413},[59,7071,7072],{"class":2416},"max_len",[59,7074,815],{"class":2405},[59,7076,6169],{"class":2455},[59,7078,2426],{"class":2413},[59,7080,7081,7083,7086,7088,7090,7092],{"class":2401,"line":2443},[59,7082,2687],{"class":2686},[59,7084,7085],{"class":2413},"(PositionalEncoding, ",[59,7087,2666],{"class":2693},[59,7089,2696],{"class":2413},[59,7091,2699],{"class":2455},[59,7093,7094],{"class":2413},"()\n",[59,7096,7097,7099,7101,7103],{"class":2401,"line":2462},[59,7098,2711],{"class":2693},[59,7100,2714],{"class":2413},[59,7102,815],{"class":2405},[59,7104,2719],{"class":2413},[59,7106,7107],{"class":2401,"line":2483},[59,7108,7109],{"class":2439},"        # 创建一个足够长的P\n",[59,7111,7112,7114,7117,7119,7122,7124],{"class":2401,"line":2491},[59,7113,2711],{"class":2693},[59,7115,7116],{"class":2413},".P ",[59,7118,815],{"class":2405},[59,7120,7121],{"class":2413}," np.zeros((",[59,7123,83],{"class":2455},[59,7125,7126],{"class":2413},", max_len, num_hiddens))\n",[59,7128,7129,7131,7133,7136,7138,7140,7142,7144,7147,7149],{"class":2401,"line":2502},[59,7130,2569],{"class":2413},[59,7132,815],{"class":2405},[59,7134,7135],{"class":2413}," np.arange(max_len).reshape(",[59,7137,2553],{"class":2405},[59,7139,83],{"class":2455},[59,7141,2420],{"class":2413},[59,7143,83],{"class":2455},[59,7145,7146],{"class":2413},") ",[59,7148,2827],{"class":2405},[59,7150,7151],{"class":2413}," np.power(\n",[59,7153,7154,7157,7160,7162,7165,7167,7169,7171],{"class":2401,"line":2519},[59,7155,7156],{"class":2455},"            10000",[59,7158,7159],{"class":2413},", np.arange(",[59,7161,1754],{"class":2455},[59,7163,7164],{"class":2413},", num_hiddens, ",[59,7166,2821],{"class":2455},[59,7168,7146],{"class":2413},[59,7170,2827],{"class":2405},[59,7172,7173],{"class":2413}," num_hiddens)\n",[59,7175,7176,7178,7181,7185,7187,7189,7191,7193,7196,7198,7201,7203],{"class":2401,"line":2535},[59,7177,2711],{"class":2693},[59,7179,7180],{"class":2413},".P[",[59,7182,7184],{"class":7183},"sDoOe",":",[59,7186,2420],{"class":2413},[59,7188,7184],{"class":7183},[59,7190,2420],{"class":2413},[59,7192,1754],{"class":2455},[59,7194,7195],{"class":7183},"::",[59,7197,2821],{"class":2455},[59,7199,7200],{"class":2413},"] ",[59,7202,815],{"class":2405},[59,7204,7205],{"class":2413}," np.sin(X)\n",[59,7207,7208,7210,7212,7214,7216,7218,7220,7222,7224,7226,7228,7230],{"class":2401,"line":2543},[59,7209,2711],{"class":2693},[59,7211,7180],{"class":2413},[59,7213,7184],{"class":7183},[59,7215,2420],{"class":2413},[59,7217,7184],{"class":7183},[59,7219,2420],{"class":2413},[59,7221,83],{"class":2455},[59,7223,7195],{"class":7183},[59,7225,2821],{"class":2455},[59,7227,7200],{"class":2413},[59,7229,815],{"class":2405},[59,7231,7232],{"class":2413}," np.cos(X)\n",[59,7234,7235],{"class":2401,"line":2560},[59,7236,2725],{"emptyLinePlaceholder":2724},[59,7238,7239,7241,7243,7245,7247,7249,7251],{"class":2401,"line":2566},[59,7240,2657],{"class":2405},[59,7242,2752],{"class":2409},[59,7244,798],{"class":2413},[59,7246,2666],{"class":2665},[59,7248,2420],{"class":2413},[59,7250,2417],{"class":2416},[59,7252,2426],{"class":2413},[59,7254,7255,7257,7259,7262,7264,7266,7268,7270,7272,7274,7277,7279,7281,7283],{"class":2401,"line":2591},[59,7256,2569],{"class":2413},[59,7258,815],{"class":2405},[59,7260,7261],{"class":2413}," X ",[59,7263,5777],{"class":2405},[59,7265,5220],{"class":2693},[59,7267,7180],{"class":2413},[59,7269,7184],{"class":7183},[59,7271,2420],{"class":2413},[59,7273,7184],{"class":7183},[59,7275,7276],{"class":2413},"X.shape[",[59,7278,83],{"class":2455},[59,7280,4794],{"class":2413},[59,7282,7184],{"class":7183},[59,7284,7285],{"class":2413},"].as_in_ctx(X.ctx)\n",[59,7287,7288,7290,7292],{"class":2401,"line":2604},[59,7289,2465],{"class":2405},[59,7291,5220],{"class":2693},[59,7293,7294],{"class":2413},".dropout(X)\n",[11,7296,7297],{"id":7297},"前馈神经网络",[15,7299,7300],{},"除了注意力子层外，我们的编码器和解码器的每一层中都包含一个全连接的前馈神经网络，该网络分别且相同地应用于每个位置。这包括两个线性变换，中间使用ReLU激活函数。",[15,7302,7303],{},[59,7304,7306,7379],{"className":7305},[62],[59,7307,7309],{"className":7308},[66],[68,7310,7311],{"xmlns":70},[72,7312,7313,7376],{},[75,7314,7315,7318,7320,7323,7325,7328,7330,7332,7335,7337,7339,7341,7343,7345,7351,7353,7360,7362,7368,7370],{},[91,7316,7317],{},"F",[91,7319,7317],{},[91,7321,7322],{},"N",[461,7324,798],{"stretchy":797},[91,7326,7327],{},"x",[461,7329,812],{"stretchy":797},[461,7331,815],{},[91,7333,7334],{},"max",[461,7336,822],{},[461,7338,798],{"stretchy":797},[81,7340,1754],{},[461,7342,803],{"separator":106},[91,7344,7327],{},[88,7346,7347,7349],{},[91,7348,2993],{},[81,7350,83],{},[461,7352,5777],{},[88,7354,7355,7358],{},[91,7356,7357],{},"b",[81,7359,83],{},[461,7361,812],{"stretchy":797},[88,7363,7364,7366],{},[91,7365,2993],{},[81,7367,2821],{},[461,7369,5777],{},[88,7371,7372,7374],{},[91,7373,7357],{},[81,7375,2821],{},[98,7377,7378],{"encoding":100}," FFN(x) = \\max(0,xW_1+b_1)W_2+b_2",[59,7380,7382,7416,7490,7588],{"className":7381,"ariaHidden":106},[105],[59,7383,7385,7388,7391,7394,7398,7401,7404,7407,7410,7413],{"className":7384},[110],[59,7386],{"className":7387,"style":877},[114],[59,7389,7317],{"className":7390,"style":3185},[119,192],[59,7392,7317],{"className":7393,"style":3185},[119,192],[59,7395,7322],{"className":7396,"style":7397},[119,192],"margin-right:0.10903em;",[59,7399,798],{"className":7400},[123],[59,7402,7327],{"className":7403},[119,192],[59,7405,812],{"className":7406},[313],[59,7408],{"className":7409,"style":500},[499],[59,7411,815],{"className":7412},[504],[59,7414],{"className":7415,"style":500},[499],[59,7417,7419,7422,7425,7428,7431,7434,7437,7440,7481,7484,7487],{"className":7418},[110],[59,7420],{"className":7421,"style":877},[114],[59,7423,7334],{"className":7424},[951],[59,7426,798],{"className":7427},[123],[59,7429,1754],{"className":7430},[119],[59,7432,803],{"className":7433},[912],[59,7435],{"className":7436,"style":916},[499],[59,7438,7327],{"className":7439},[119,192],[59,7441,7443,7446],{"className":7442},[119],[59,7444,2993],{"className":7445,"style":3185},[119,192],[59,7447,7449],{"className":7448},[196],[59,7450,7452,7473],{"className":7451},[131,132],[59,7453,7455,7470],{"className":7454},[136],[59,7456,7458],{"className":7457,"style":3091},[140],[59,7459,7461,7464],{"style":7460},"top:-2.55em;margin-left:-0.1389em;margin-right:0.05em;",[59,7462],{"className":7463,"style":540},[148],[59,7465,7467],{"className":7466},[153,154,155,156],[59,7468,83],{"className":7469},[119,156],[59,7471,227],{"className":7472},[226],[59,7474,7476],{"className":7475},[136],[59,7477,7479],{"className":7478,"style":1270},[140],[59,7480],{},[59,7482],{"className":7483,"style":1872},[499],[59,7485,5777],{"className":7486},[553],[59,7488],{"className":7489,"style":1872},[499],[59,7491,7493,7496,7536,7539,7579,7582,7585],{"className":7492},[110],[59,7494],{"className":7495,"style":877},[114],[59,7497,7499,7502],{"className":7498},[119],[59,7500,7357],{"className":7501},[119,192],[59,7503,7505],{"className":7504},[196],[59,7506,7508,7528],{"className":7507},[131,132],[59,7509,7511,7525],{"className":7510},[136],[59,7512,7514],{"className":7513,"style":3091},[140],[59,7515,7516,7519],{"style":1251},[59,7517],{"className":7518,"style":540},[148],[59,7520,7522],{"className":7521},[153,154,155,156],[59,7523,83],{"className":7524},[119,156],[59,7526,227],{"className":7527},[226],[59,7529,7531],{"className":7530},[136],[59,7532,7534],{"className":7533,"style":1270},[140],[59,7535],{},[59,7537,812],{"className":7538},[313],[59,7540,7542,7545],{"className":7541},[119],[59,7543,2993],{"className":7544,"style":3185},[119,192],[59,7546,7548],{"className":7547},[196],[59,7549,7551,7571],{"className":7550},[131,132],[59,7552,7554,7568],{"className":7553},[136],[59,7555,7557],{"className":7556,"style":3091},[140],[59,7558,7559,7562],{"style":7460},[59,7560],{"className":7561,"style":540},[148],[59,7563,7565],{"className":7564},[153,154,155,156],[59,7566,2821],{"className":7567},[119,156],[59,7569,227],{"className":7570},[226],[59,7572,7574],{"className":7573},[136],[59,7575,7577],{"className":7576,"style":1270},[140],[59,7578],{},[59,7580],{"className":7581,"style":1872},[499],[59,7583,5777],{"className":7584},[553],[59,7586],{"className":7587,"style":1872},[499],[59,7589,7591,7594],{"className":7590},[110],[59,7592],{"className":7593,"style":1229},[114],[59,7595,7597,7600],{"className":7596},[119],[59,7598,7357],{"className":7599},[119,192],[59,7601,7603],{"className":7602},[196],[59,7604,7606,7626],{"className":7605},[131,132],[59,7607,7609,7623],{"className":7608},[136],[59,7610,7612],{"className":7611,"style":3091},[140],[59,7613,7614,7617],{"style":1251},[59,7615],{"className":7616,"style":540},[148],[59,7618,7620],{"className":7619},[153,154,155,156],[59,7621,2821],{"className":7622},[119,156],[59,7624,227],{"className":7625},[226],[59,7627,7629],{"className":7628},[136],[59,7630,7632],{"className":7631,"style":1270},[140],[59,7633],{},[2390,7635,7637],{"className":2392,"code":7636,"language":2394,"meta":2395,"style":2395},"class PositionWiseFFN(nn.Module):\n    \"\"\"基于位置的前馈网络\"\"\"\n    def __init__(self, ffn_num_input, ffn_num_hiddens, ffn_num_outputs,\n                 **kwargs):\n        super(PositionWiseFFN, self).__init__(**kwargs)\n        self.dense1 = nn.Linear(ffn_num_input, ffn_num_hiddens)\n        self.relu = nn.ReLU()\n        self.dense2 = nn.Linear(ffn_num_hiddens, ffn_num_outputs)\n\n    def forward(self, X):\n        return self.dense2(self.relu(self.dense1(X)))\n",[2397,7638,7639,7656,7661,7688,7697,7716,7728,7740,7752,7756,7772],{"__ignoreMap":2395},[59,7640,7641,7643,7646,7648,7650,7652,7654],{"class":2401,"line":2402},[59,7642,2630],{"class":2405},[59,7644,7645],{"class":2633}," PositionWiseFFN",[59,7647,798],{"class":2413},[59,7649,2640],{"class":2639},[59,7651,864],{"class":2413},[59,7653,2645],{"class":2639},[59,7655,2426],{"class":2413},[59,7657,7658],{"class":2401,"line":2429},[59,7659,7660],{"class":2432},"    \"\"\"基于位置的前馈网络\"\"\"\n",[59,7662,7663,7665,7667,7669,7671,7673,7676,7678,7681,7683,7686],{"class":2401,"line":2436},[59,7664,2657],{"class":2405},[59,7666,2660],{"class":2455},[59,7668,798],{"class":2413},[59,7670,2666],{"class":2665},[59,7672,2420],{"class":2413},[59,7674,7675],{"class":2416},"ffn_num_input",[59,7677,2420],{"class":2413},[59,7679,7680],{"class":2416},"ffn_num_hiddens",[59,7682,2420],{"class":2413},[59,7684,7685],{"class":2416},"ffn_num_outputs",[59,7687,3826],{"class":2413},[59,7689,7690,7693,7695],{"class":2401,"line":2443},[59,7691,7692],{"class":2405},"                 **",[59,7694,2679],{"class":2416},[59,7696,2426],{"class":2413},[59,7698,7699,7701,7704,7706,7708,7710,7712,7714],{"class":2401,"line":2462},[59,7700,2687],{"class":2686},[59,7702,7703],{"class":2413},"(PositionWiseFFN, ",[59,7705,2666],{"class":2693},[59,7707,2696],{"class":2413},[59,7709,2699],{"class":2455},[59,7711,798],{"class":2413},[59,7713,2676],{"class":2405},[59,7715,2706],{"class":2413},[59,7717,7718,7720,7723,7725],{"class":2401,"line":2483},[59,7719,2711],{"class":2693},[59,7721,7722],{"class":2413},".dense1 ",[59,7724,815],{"class":2405},[59,7726,7727],{"class":2413}," nn.Linear(ffn_num_input, ffn_num_hiddens)\n",[59,7729,7730,7732,7735,7737],{"class":2401,"line":2491},[59,7731,2711],{"class":2693},[59,7733,7734],{"class":2413},".relu ",[59,7736,815],{"class":2405},[59,7738,7739],{"class":2413}," nn.ReLU()\n",[59,7741,7742,7744,7747,7749],{"class":2401,"line":2502},[59,7743,2711],{"class":2693},[59,7745,7746],{"class":2413},".dense2 ",[59,7748,815],{"class":2405},[59,7750,7751],{"class":2413}," nn.Linear(ffn_num_hiddens, ffn_num_outputs)\n",[59,7753,7754],{"class":2401,"line":2519},[59,7755,2725],{"emptyLinePlaceholder":2724},[59,7757,7758,7760,7762,7764,7766,7768,7770],{"class":2401,"line":2535},[59,7759,2657],{"class":2405},[59,7761,2752],{"class":2409},[59,7763,798],{"class":2413},[59,7765,2666],{"class":2665},[59,7767,2420],{"class":2413},[59,7769,2417],{"class":2416},[59,7771,2426],{"class":2413},[59,7773,7774,7776,7778,7781,7783,7786,7788],{"class":2401,"line":2543},[59,7775,2465],{"class":2405},[59,7777,5220],{"class":2693},[59,7779,7780],{"class":2413},".dense2(",[59,7782,2666],{"class":2693},[59,7784,7785],{"class":2413},".relu(",[59,7787,2666],{"class":2693},[59,7789,7790],{"class":2413},".dense1(X)))\n",[11,7792,7793],{"id":7793},"残差连接和层规范化",[15,7795,7796],{},"我们在每个子层之间加入了一个残差连接和层归一化。",[15,7798,7799,7800,4568,7925,7988],{},"因此，每个子层的输出都是 ",[59,7801,7803,7871],{"className":7802},[62],[59,7804,7806],{"className":7805},[66],[68,7807,7808],{"xmlns":70},[72,7809,7810,7868],{},[75,7811,7812,7835,7837,7839,7841,7860,7862,7864,7866],{},[75,7813,7814,7817,7819,7822,7824,7827,7829,7831,7833],{},[91,7815,7816],{"mathvariant":818},"L",[91,7818,22],{"mathvariant":818},[91,7820,7821],{"mathvariant":818},"y",[91,7823,782],{"mathvariant":818},[91,7825,7826],{"mathvariant":818},"r",[91,7828,7322],{"mathvariant":818},[91,7830,792],{"mathvariant":818},[91,7832,7826],{"mathvariant":818},[91,7834,366],{"mathvariant":818},[461,7836,798],{"stretchy":797},[91,7838,7327],{},[461,7840,5777],{},[75,7842,7843,7846,7848,7850,7852,7854,7856,7858],{},[91,7844,7845],{"mathvariant":818},"S",[91,7847,2900],{"mathvariant":818},[91,7849,7357],{"mathvariant":818},[91,7851,2903],{"mathvariant":818},[91,7853,22],{"mathvariant":818},[91,7855,7821],{"mathvariant":818},[91,7857,782],{"mathvariant":818},[91,7859,7826],{"mathvariant":818},[461,7861,798],{"stretchy":797},[91,7863,7327],{},[461,7865,812],{"stretchy":797},[461,7867,812],{"stretchy":797},[98,7869,7870],{"encoding":100},"\\mathrm{LayerNorm}(x + \\mathrm{Sublayer}(x))",[59,7872,7874,7902],{"className":7873,"ariaHidden":106},[105],[59,7875,7877,7880,7887,7890,7893,7896,7899],{"className":7876},[110],[59,7878],{"className":7879,"style":877},[114],[59,7881,7883],{"className":7882},[119],[59,7884,7886],{"className":7885},[119,955],"LayerNorm",[59,7888,798],{"className":7889},[123],[59,7891,7327],{"className":7892},[119,192],[59,7894],{"className":7895,"style":1872},[499],[59,7897,5777],{"className":7898},[553],[59,7900],{"className":7901,"style":1872},[499],[59,7903,7905,7908,7915,7918,7921],{"className":7904},[110],[59,7906],{"className":7907,"style":877},[114],[59,7909,7911],{"className":7910},[119],[59,7912,7914],{"className":7913},[119,955],"Sublayer",[59,7916,798],{"className":7917},[123],[59,7919,7327],{"className":7920},[119,192],[59,7922,7924],{"className":7923},[313],"))",[59,7926,7928,7964],{"className":7927},[62],[59,7929,7931],{"className":7930},[66],[68,7932,7933],{"xmlns":70},[72,7934,7935,7961],{},[75,7936,7937,7955,7957,7959],{},[75,7938,7939,7941,7943,7945,7947,7949,7951,7953],{},[91,7940,7845],{"mathvariant":818},[91,7942,2900],{"mathvariant":818},[91,7944,7357],{"mathvariant":818},[91,7946,2903],{"mathvariant":818},[91,7948,22],{"mathvariant":818},[91,7950,7821],{"mathvariant":818},[91,7952,782],{"mathvariant":818},[91,7954,7826],{"mathvariant":818},[461,7956,798],{"stretchy":797},[91,7958,7327],{},[461,7960,812],{"stretchy":797},[98,7962,7963],{"encoding":100},"\\mathrm{Sublayer}(x)",[59,7965,7967],{"className":7966,"ariaHidden":106},[105],[59,7968,7970,7973,7979,7982,7985],{"className":7969},[110],[59,7971],{"className":7972,"style":877},[114],[59,7974,7976],{"className":7975},[119],[59,7977,7914],{"className":7978},[119,955],[59,7980,798],{"className":7981},[123],[59,7983,7327],{"className":7984},[119,192],[59,7986,812],{"className":7987},[313]," 是由该子层本身实现的函数。我们在每个子层的输出上应用 dropout，然后再将其加到子层输入并进行归一化。",[2390,7990,7992],{"className":2392,"code":7991,"language":2394,"meta":2395,"style":2395},"class AddNorm(nn.Module):\n    \"\"\"残差连接后进行层规范化\"\"\"\n    def __init__(self, normalized_shape, dropout, **kwargs):\n        super(AddNorm, self).__init__(**kwargs)\n        self.dropout = nn.Dropout(dropout)\n        self.ln = nn.LayerNorm(normalized_shape)\n\n    def forward(self, X, Y):\n        return self.ln(self.dropout(Y) + X)\n",[2397,7993,7994,8011,8016,8043,8062,8072,8084,8088,8109],{"__ignoreMap":2395},[59,7995,7996,7998,8001,8003,8005,8007,8009],{"class":2401,"line":2402},[59,7997,2630],{"class":2405},[59,7999,8000],{"class":2633}," AddNorm",[59,8002,798],{"class":2413},[59,8004,2640],{"class":2639},[59,8006,864],{"class":2413},[59,8008,2645],{"class":2639},[59,8010,2426],{"class":2413},[59,8012,8013],{"class":2401,"line":2429},[59,8014,8015],{"class":2432},"    \"\"\"残差连接后进行层规范化\"\"\"\n",[59,8017,8018,8020,8022,8024,8026,8028,8031,8033,8035,8037,8039,8041],{"class":2401,"line":2436},[59,8019,2657],{"class":2405},[59,8021,2660],{"class":2455},[59,8023,798],{"class":2413},[59,8025,2666],{"class":2665},[59,8027,2420],{"class":2413},[59,8029,8030],{"class":2416},"normalized_shape",[59,8032,2420],{"class":2413},[59,8034,2671],{"class":2416},[59,8036,2420],{"class":2413},[59,8038,2676],{"class":2405},[59,8040,2679],{"class":2416},[59,8042,2426],{"class":2413},[59,8044,8045,8047,8050,8052,8054,8056,8058,8060],{"class":2401,"line":2443},[59,8046,2687],{"class":2686},[59,8048,8049],{"class":2413},"(AddNorm, ",[59,8051,2666],{"class":2693},[59,8053,2696],{"class":2413},[59,8055,2699],{"class":2455},[59,8057,798],{"class":2413},[59,8059,2676],{"class":2405},[59,8061,2706],{"class":2413},[59,8063,8064,8066,8068,8070],{"class":2401,"line":2462},[59,8065,2711],{"class":2693},[59,8067,2714],{"class":2413},[59,8069,815],{"class":2405},[59,8071,2719],{"class":2413},[59,8073,8074,8076,8079,8081],{"class":2401,"line":2483},[59,8075,2711],{"class":2693},[59,8077,8078],{"class":2413},".ln ",[59,8080,815],{"class":2405},[59,8082,8083],{"class":2413}," nn.LayerNorm(normalized_shape)\n",[59,8085,8086],{"class":2401,"line":2491},[59,8087,2725],{"emptyLinePlaceholder":2724},[59,8089,8090,8092,8094,8096,8098,8100,8102,8104,8107],{"class":2401,"line":2502},[59,8091,2657],{"class":2405},[59,8093,2752],{"class":2409},[59,8095,798],{"class":2413},[59,8097,2666],{"class":2665},[59,8099,2420],{"class":2413},[59,8101,2417],{"class":2416},[59,8103,2420],{"class":2413},[59,8105,8106],{"class":2416},"Y",[59,8108,2426],{"class":2413},[59,8110,8111,8113,8115,8118,8120,8123,8125],{"class":2401,"line":2519},[59,8112,2465],{"class":2405},[59,8114,5220],{"class":2693},[59,8116,8117],{"class":2413},".ln(",[59,8119,2666],{"class":2693},[59,8121,8122],{"class":2413},".dropout(Y) ",[59,8124,5777],{"class":2405},[59,8126,8127],{"class":2413}," X)\n",[11,8129,8130],{"id":8130},"编码器-解码器架构",[15,8132,8133,8134,8284,8285,8461,8462,8491,8492,8642],{},"大多数主流的神经序列转换模型都采用编码器—解码器结构。其中，编码器将符号表示组成的输入序列 ",[59,8135,8137,8171],{"className":8136},[62],[59,8138,8140],{"className":8139},[66],[68,8141,8142],{"xmlns":70},[72,8143,8144,8168],{},[75,8145,8146,8148,8154,8156,8158,8160,8166],{},[461,8147,798],{"stretchy":797},[88,8149,8150,8152],{},[91,8151,7327],{},[81,8153,83],{},[461,8155,803],{"separator":106},[461,8157,2970],{},[461,8159,803],{"separator":106},[88,8161,8162,8164],{},[91,8163,7327],{},[91,8165,335],{},[461,8167,812],{"stretchy":797},[98,8169,8170],{"encoding":100},"(x_1, \\ldots, x_n)",[59,8172,8174],{"className":8173,"ariaHidden":106},[105],[59,8175,8177,8180,8183,8223,8226,8229,8232,8235,8238,8241,8281],{"className":8176},[110],[59,8178],{"className":8179,"style":877},[114],[59,8181,798],{"className":8182},[123],[59,8184,8186,8189],{"className":8185},[119],[59,8187,7327],{"className":8188},[119,192],[59,8190,8192],{"className":8191},[196],[59,8193,8195,8215],{"className":8194},[131,132],[59,8196,8198,8212],{"className":8197},[136],[59,8199,8201],{"className":8200,"style":3091},[140],[59,8202,8203,8206],{"style":1251},[59,8204],{"className":8205,"style":540},[148],[59,8207,8209],{"className":8208},[153,154,155,156],[59,8210,83],{"className":8211},[119,156],[59,8213,227],{"className":8214},[226],[59,8216,8218],{"className":8217},[136],[59,8219,8221],{"className":8220,"style":1270},[140],[59,8222],{},[59,8224,803],{"className":8225},[912],[59,8227],{"className":8228,"style":916},[499],[59,8230,2970],{"className":8231},[962],[59,8233],{"className":8234,"style":916},[499],[59,8236,803],{"className":8237},[912],[59,8239],{"className":8240,"style":916},[499],[59,8242,8244,8247],{"className":8243},[119],[59,8245,7327],{"className":8246},[119,192],[59,8248,8250],{"className":8249},[196],[59,8251,8253,8273],{"className":8252},[131,132],[59,8254,8256,8270],{"className":8255},[136],[59,8257,8259],{"className":8258,"style":1322},[140],[59,8260,8261,8264],{"style":1251},[59,8262],{"className":8263,"style":540},[148],[59,8265,8267],{"className":8266},[153,154,155,156],[59,8268,335],{"className":8269},[119,192,156],[59,8271,227],{"className":8272},[226],[59,8274,8276],{"className":8275},[136],[59,8277,8279],{"className":8278,"style":1270},[140],[59,8280],{},[59,8282,812],{"className":8283},[313]," 映射为连续表示序列 ",[59,8286,8288,8327],{"className":8287},[62],[59,8289,8291],{"className":8290},[66],[68,8292,8293],{"xmlns":70},[72,8294,8295,8324],{},[75,8296,8297,8300,8302,8304,8310,8312,8314,8316,8322],{},[91,8298,8299],{"mathvariant":5378},"z",[461,8301,815],{},[461,8303,798],{"stretchy":797},[88,8305,8306,8308],{},[91,8307,8299],{},[81,8309,83],{},[461,8311,803],{"separator":106},[461,8313,2970],{},[461,8315,803],{"separator":106},[88,8317,8318,8320],{},[91,8319,8299],{},[91,8321,335],{},[461,8323,812],{"stretchy":797},[98,8325,8326],{"encoding":100},"\\mathbf{z} = (z_1, \\ldots, z_n)",[59,8328,8330,8349],{"className":8329,"ariaHidden":106},[105],[59,8331,8333,8337,8340,8343,8346],{"className":8332},[110],[59,8334],{"className":8335,"style":8336},[114],"height:0.4444em;",[59,8338,8299],{"className":8339},[119,5423],[59,8341],{"className":8342,"style":500},[499],[59,8344,815],{"className":8345},[504],[59,8347],{"className":8348,"style":500},[499],[59,8350,8352,8355,8358,8400,8403,8406,8409,8412,8415,8418,8458],{"className":8351},[110],[59,8353],{"className":8354,"style":877},[114],[59,8356,798],{"className":8357},[123],[59,8359,8361,8365],{"className":8360},[119],[59,8362,8299],{"className":8363,"style":8364},[119,192],"margin-right:0.04398em;",[59,8366,8368],{"className":8367},[196],[59,8369,8371,8392],{"className":8370},[131,132],[59,8372,8374,8389],{"className":8373},[136],[59,8375,8377],{"className":8376,"style":3091},[140],[59,8378,8380,8383],{"style":8379},"top:-2.55em;margin-left:-0.044em;margin-right:0.05em;",[59,8381],{"className":8382,"style":540},[148],[59,8384,8386],{"className":8385},[153,154,155,156],[59,8387,83],{"className":8388},[119,156],[59,8390,227],{"className":8391},[226],[59,8393,8395],{"className":8394},[136],[59,8396,8398],{"className":8397,"style":1270},[140],[59,8399],{},[59,8401,803],{"className":8402},[912],[59,8404],{"className":8405,"style":916},[499],[59,8407,2970],{"className":8408},[962],[59,8410],{"className":8411,"style":916},[499],[59,8413,803],{"className":8414},[912],[59,8416],{"className":8417,"style":916},[499],[59,8419,8421,8424],{"className":8420},[119],[59,8422,8299],{"className":8423,"style":8364},[119,192],[59,8425,8427],{"className":8426},[196],[59,8428,8430,8450],{"className":8429},[131,132],[59,8431,8433,8447],{"className":8432},[136],[59,8434,8436],{"className":8435,"style":1322},[140],[59,8437,8438,8441],{"style":8379},[59,8439],{"className":8440,"style":540},[148],[59,8442,8444],{"className":8443},[153,154,155,156],[59,8445,335],{"className":8446},[119,192,156],[59,8448,227],{"className":8449},[226],[59,8451,8453],{"className":8452},[136],[59,8454,8456],{"className":8455,"style":1270},[140],[59,8457],{},[59,8459,812],{"className":8460},[313],"。给定 ",[59,8463,8465,8479],{"className":8464},[62],[59,8466,8468],{"className":8467},[66],[68,8469,8470],{"xmlns":70},[72,8471,8472,8476],{},[75,8473,8474],{},[91,8475,8299],{"mathvariant":5378},[98,8477,8478],{"encoding":100},"\\mathbf{z}",[59,8480,8482],{"className":8481,"ariaHidden":106},[105],[59,8483,8485,8488],{"className":8484},[110],[59,8486],{"className":8487,"style":8336},[114],[59,8489,8299],{"className":8490},[119,5423],"，解码器随后一次生成一个元素的输出符号序列 ",[59,8493,8495,8529],{"className":8494},[62],[59,8496,8498],{"className":8497},[66],[68,8499,8500],{"xmlns":70},[72,8501,8502,8526],{},[75,8503,8504,8506,8512,8514,8516,8518,8524],{},[461,8505,798],{"stretchy":797},[88,8507,8508,8510],{},[91,8509,7821],{},[81,8511,83],{},[461,8513,803],{"separator":106},[461,8515,2970],{},[461,8517,803],{"separator":106},[88,8519,8520,8522],{},[91,8521,7821],{},[91,8523,366],{},[461,8525,812],{"stretchy":797},[98,8527,8528],{"encoding":100},"(y_1, \\ldots, y_m)",[59,8530,8532],{"className":8531,"ariaHidden":106},[105],[59,8533,8535,8538,8541,8581,8584,8587,8590,8593,8596,8599,8639],{"className":8534},[110],[59,8536],{"className":8537,"style":877},[114],[59,8539,798],{"className":8540},[123],[59,8542,8544,8547],{"className":8543},[119],[59,8545,7821],{"className":8546,"style":441},[119,192],[59,8548,8550],{"className":8549},[196],[59,8551,8553,8573],{"className":8552},[131,132],[59,8554,8556,8570],{"className":8555},[136],[59,8557,8559],{"className":8558,"style":3091},[140],[59,8560,8561,8564],{"style":2035},[59,8562],{"className":8563,"style":540},[148],[59,8565,8567],{"className":8566},[153,154,155,156],[59,8568,83],{"className":8569},[119,156],[59,8571,227],{"className":8572},[226],[59,8574,8576],{"className":8575},[136],[59,8577,8579],{"className":8578,"style":1270},[140],[59,8580],{},[59,8582,803],{"className":8583},[912],[59,8585],{"className":8586,"style":916},[499],[59,8588,2970],{"className":8589},[962],[59,8591],{"className":8592,"style":916},[499],[59,8594,803],{"className":8595},[912],[59,8597],{"className":8598,"style":916},[499],[59,8600,8602,8605],{"className":8601},[119],[59,8603,7821],{"className":8604,"style":441},[119,192],[59,8606,8608],{"className":8607},[196],[59,8609,8611,8631],{"className":8610},[131,132],[59,8612,8614,8628],{"className":8613},[136],[59,8615,8617],{"className":8616,"style":1322},[140],[59,8618,8619,8622],{"style":2035},[59,8620],{"className":8621,"style":540},[148],[59,8623,8625],{"className":8624},[153,154,155,156],[59,8626,366],{"className":8627},[119,192,156],[59,8629,227],{"className":8630},[226],[59,8632,8634],{"className":8633},[136],[59,8635,8637],{"className":8636,"style":1270},[140],[59,8638],{},[59,8640,812],{"className":8641},[313],"。在每一个时间步，模型都是自回归的，即在生成下一个符号时，会将之前已经生成的符号作为额外输入。",[2390,8644,8646],{"className":2392,"code":8645,"language":2394,"meta":2395,"style":2395},"class Encoder(nn.Module):\n    \"\"\"编码器-解码器架构的基本编码器接口\"\"\"\n    def __init__(self, **kwargs):\n        super(Encoder, self).__init__(**kwargs)\n\n    def forward(self, X, *args):\n        raise NotImplementedError\n\nclass Decoder(nn.Module):\n    \"\"\"编码器-解码器架构的基本解码器接口\"\"\"\n    def __init__(self, **kwargs):\n        super(Decoder, self).__init__(**kwargs)\n\n    def init_state(self, enc_outputs, *args):\n        raise NotImplementedError\n\n    def forward(self, X, state):\n        raise NotImplementedError\n\nclass EncoderDecoder(nn.Module):\n    \"\"\"编码器-解码器架构的基类\"\"\"\n    def __init__(self, encoder, decoder, **kwargs):\n        super(EncoderDecoder, self).__init__(**kwargs)\n        self.encoder = encoder\n        self.decoder = decoder\n\n    def forward(self, enc_X, dec_X, *args):\n        enc_outputs = self.encoder(enc_X, *args)\n        dec_state = self.decoder.init_state(enc_outputs, *args)\n        return self.decoder(dec_X, dec_state)\n",[2397,8647,8648,8665,8670,8688,8707,8711,8735,8743,8747,8764,8769,8787,8806,8810,8834,8840,8844,8865,8871,8875,8892,8897,8925,8944,8956,8968,8972,9000,9017,9033],{"__ignoreMap":2395},[59,8649,8650,8652,8655,8657,8659,8661,8663],{"class":2401,"line":2402},[59,8651,2630],{"class":2405},[59,8653,8654],{"class":2633}," Encoder",[59,8656,798],{"class":2413},[59,8658,2640],{"class":2639},[59,8660,864],{"class":2413},[59,8662,2645],{"class":2639},[59,8664,2426],{"class":2413},[59,8666,8667],{"class":2401,"line":2429},[59,8668,8669],{"class":2432},"    \"\"\"编码器-解码器架构的基本编码器接口\"\"\"\n",[59,8671,8672,8674,8676,8678,8680,8682,8684,8686],{"class":2401,"line":2436},[59,8673,2657],{"class":2405},[59,8675,2660],{"class":2455},[59,8677,798],{"class":2413},[59,8679,2666],{"class":2665},[59,8681,2420],{"class":2413},[59,8683,2676],{"class":2405},[59,8685,2679],{"class":2416},[59,8687,2426],{"class":2413},[59,8689,8690,8692,8695,8697,8699,8701,8703,8705],{"class":2401,"line":2443},[59,8691,2687],{"class":2686},[59,8693,8694],{"class":2413},"(Encoder, ",[59,8696,2666],{"class":2693},[59,8698,2696],{"class":2413},[59,8700,2699],{"class":2455},[59,8702,798],{"class":2413},[59,8704,2676],{"class":2405},[59,8706,2706],{"class":2413},[59,8708,8709],{"class":2401,"line":2462},[59,8710,2725],{"emptyLinePlaceholder":2724},[59,8712,8713,8715,8717,8719,8721,8723,8725,8727,8730,8733],{"class":2401,"line":2483},[59,8714,2657],{"class":2405},[59,8716,2752],{"class":2409},[59,8718,798],{"class":2413},[59,8720,2666],{"class":2665},[59,8722,2420],{"class":2413},[59,8724,2417],{"class":2416},[59,8726,2420],{"class":2413},[59,8728,8729],{"class":2405},"*",[59,8731,8732],{"class":2416},"args",[59,8734,2426],{"class":2413},[59,8736,8737,8740],{"class":2401,"line":2491},[59,8738,8739],{"class":2405},"        raise",[59,8741,8742],{"class":2686}," NotImplementedError\n",[59,8744,8745],{"class":2401,"line":2502},[59,8746,2725],{"emptyLinePlaceholder":2724},[59,8748,8749,8751,8754,8756,8758,8760,8762],{"class":2401,"line":2519},[59,8750,2630],{"class":2405},[59,8752,8753],{"class":2633}," Decoder",[59,8755,798],{"class":2413},[59,8757,2640],{"class":2639},[59,8759,864],{"class":2413},[59,8761,2645],{"class":2639},[59,8763,2426],{"class":2413},[59,8765,8766],{"class":2401,"line":2535},[59,8767,8768],{"class":2432},"    \"\"\"编码器-解码器架构的基本解码器接口\"\"\"\n",[59,8770,8771,8773,8775,8777,8779,8781,8783,8785],{"class":2401,"line":2543},[59,8772,2657],{"class":2405},[59,8774,2660],{"class":2455},[59,8776,798],{"class":2413},[59,8778,2666],{"class":2665},[59,8780,2420],{"class":2413},[59,8782,2676],{"class":2405},[59,8784,2679],{"class":2416},[59,8786,2426],{"class":2413},[59,8788,8789,8791,8794,8796,8798,8800,8802,8804],{"class":2401,"line":2560},[59,8790,2687],{"class":2686},[59,8792,8793],{"class":2413},"(Decoder, ",[59,8795,2666],{"class":2693},[59,8797,2696],{"class":2413},[59,8799,2699],{"class":2455},[59,8801,798],{"class":2413},[59,8803,2676],{"class":2405},[59,8805,2706],{"class":2413},[59,8807,8808],{"class":2401,"line":2566},[59,8809,2725],{"emptyLinePlaceholder":2724},[59,8811,8812,8814,8817,8819,8821,8823,8826,8828,8830,8832],{"class":2401,"line":2591},[59,8813,2657],{"class":2405},[59,8815,8816],{"class":2409}," init_state",[59,8818,798],{"class":2413},[59,8820,2666],{"class":2665},[59,8822,2420],{"class":2413},[59,8824,8825],{"class":2416},"enc_outputs",[59,8827,2420],{"class":2413},[59,8829,8729],{"class":2405},[59,8831,8732],{"class":2416},[59,8833,2426],{"class":2413},[59,8835,8836,8838],{"class":2401,"line":2604},[59,8837,8739],{"class":2405},[59,8839,8742],{"class":2686},[59,8841,8842],{"class":2401,"line":2845},[59,8843,2725],{"emptyLinePlaceholder":2724},[59,8845,8846,8848,8850,8852,8854,8856,8858,8860,8863],{"class":2401,"line":4781},[59,8847,2657],{"class":2405},[59,8849,2752],{"class":2409},[59,8851,798],{"class":2413},[59,8853,2666],{"class":2665},[59,8855,2420],{"class":2413},[59,8857,2417],{"class":2416},[59,8859,2420],{"class":2413},[59,8861,8862],{"class":2416},"state",[59,8864,2426],{"class":2413},[59,8866,8867,8869],{"class":2401,"line":5057},[59,8868,8739],{"class":2405},[59,8870,8742],{"class":2686},[59,8872,8873],{"class":2401,"line":5063},[59,8874,2725],{"emptyLinePlaceholder":2724},[59,8876,8877,8879,8882,8884,8886,8888,8890],{"class":2401,"line":5069},[59,8878,2630],{"class":2405},[59,8880,8881],{"class":2633}," EncoderDecoder",[59,8883,798],{"class":2413},[59,8885,2640],{"class":2639},[59,8887,864],{"class":2413},[59,8889,2645],{"class":2639},[59,8891,2426],{"class":2413},[59,8893,8894],{"class":2401,"line":5075},[59,8895,8896],{"class":2432},"    \"\"\"编码器-解码器架构的基类\"\"\"\n",[59,8898,8899,8901,8903,8905,8907,8909,8912,8914,8917,8919,8921,8923],{"class":2401,"line":5096},[59,8900,2657],{"class":2405},[59,8902,2660],{"class":2455},[59,8904,798],{"class":2413},[59,8906,2666],{"class":2665},[59,8908,2420],{"class":2413},[59,8910,8911],{"class":2416},"encoder",[59,8913,2420],{"class":2413},[59,8915,8916],{"class":2416},"decoder",[59,8918,2420],{"class":2413},[59,8920,2676],{"class":2405},[59,8922,2679],{"class":2416},[59,8924,2426],{"class":2413},[59,8926,8927,8929,8932,8934,8936,8938,8940,8942],{"class":2401,"line":5115},[59,8928,2687],{"class":2686},[59,8930,8931],{"class":2413},"(EncoderDecoder, ",[59,8933,2666],{"class":2693},[59,8935,2696],{"class":2413},[59,8937,2699],{"class":2455},[59,8939,798],{"class":2413},[59,8941,2676],{"class":2405},[59,8943,2706],{"class":2413},[59,8945,8946,8948,8951,8953],{"class":2401,"line":5134},[59,8947,2711],{"class":2693},[59,8949,8950],{"class":2413},".encoder ",[59,8952,815],{"class":2405},[59,8954,8955],{"class":2413}," encoder\n",[59,8957,8958,8960,8963,8965],{"class":2401,"line":5139},[59,8959,2711],{"class":2693},[59,8961,8962],{"class":2413},".decoder ",[59,8964,815],{"class":2405},[59,8966,8967],{"class":2413}," decoder\n",[59,8969,8970],{"class":2401,"line":5155},[59,8971,2725],{"emptyLinePlaceholder":2724},[59,8973,8974,8976,8978,8980,8982,8984,8987,8989,8992,8994,8996,8998],{"class":2401,"line":5161},[59,8975,2657],{"class":2405},[59,8977,2752],{"class":2409},[59,8979,798],{"class":2413},[59,8981,2666],{"class":2665},[59,8983,2420],{"class":2413},[59,8985,8986],{"class":2416},"enc_X",[59,8988,2420],{"class":2413},[59,8990,8991],{"class":2416},"dec_X",[59,8993,2420],{"class":2413},[59,8995,8729],{"class":2405},[59,8997,8732],{"class":2416},[59,8999,2426],{"class":2413},[59,9001,9002,9005,9007,9009,9012,9014],{"class":2401,"line":5167},[59,9003,9004],{"class":2413},"        enc_outputs ",[59,9006,815],{"class":2405},[59,9008,5220],{"class":2693},[59,9010,9011],{"class":2413},".encoder(enc_X, ",[59,9013,8729],{"class":2405},[59,9015,9016],{"class":2413},"args)\n",[59,9018,9019,9022,9024,9026,9029,9031],{"class":2401,"line":5177},[59,9020,9021],{"class":2413},"        dec_state ",[59,9023,815],{"class":2405},[59,9025,5220],{"class":2693},[59,9027,9028],{"class":2413},".decoder.init_state(enc_outputs, ",[59,9030,8729],{"class":2405},[59,9032,9016],{"class":2413},[59,9034,9035,9037,9039],{"class":2401,"line":5201},[59,9036,2465],{"class":2405},[59,9038,5220],{"class":2693},[59,9040,9041],{"class":2413},".decoder(dec_X, dec_state)\n",[11,9043,9045],{"id":9044},"transformer-编码器","transformer 编码器",[15,9047,9048],{},"代码实现如下：",[2390,9050,9052],{"className":2392,"code":9051,"language":2394,"meta":2395,"style":2395},"class EncoderBlock(nn.Module):\n    \"\"\"Transformer编码器块\"\"\"\n    def __init__(self, key_size, query_size, value_size, num_hiddens,\n                 norm_shape, ffn_num_input, ffn_num_hiddens, num_heads,\n                 dropout, use_bias=False, **kwargs):\n        super(EncoderBlock, self).__init__(**kwargs)\n        self.attention = d2l.MultiHeadAttention(\n            key_size, query_size, value_size, num_hiddens, num_heads, dropout,\n            use_bias)\n        self.addnorm1 = AddNorm(norm_shape, dropout)\n        self.ffn = PositionWiseFFN(\n            ffn_num_input, ffn_num_hiddens, num_hiddens)\n        self.addnorm2 = AddNorm(norm_shape, dropout)\n\n    def forward(self, X, valid_lens):\n        Y = self.addnorm1(X, self.attention(X, X, X, valid_lens))\n        return self.addnorm2(Y, self.ffn(Y))\n\nclass TransformerEncoder(Encoder):\n    \"\"\"Transformer编码器\"\"\"\n    def __init__(self, vocab_size, key_size, query_size, value_size,\n                 num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens,\n                 num_heads, num_layers, dropout, use_bias=False, **kwargs):\n        super(TransformerEncoder, self).__init__(**kwargs)\n        self.num_hiddens = num_hiddens\n        self.embedding = nn.Embedding(vocab_size, num_hiddens)\n        self.pos_encoding = PositionalEncoding(num_hiddens, dropout)\n        self.blks = nn.Sequential()\n        for i in range(num_layers):\n            self.blks.add_module(\"block\"+str(i),\n                EncoderBlock(key_size, query_size, value_size, num_hiddens,\n                             norm_shape, ffn_num_input, ffn_num_hiddens,\n                             num_heads, dropout, use_bias))\n\n    def forward(self, X, valid_lens, *args):\n        # 因为位置编码值在-1和1之间，\n        # 因此嵌入值乘以嵌入维度的平方根进行缩放，\n        # 然后再与位置编码相加。\n        X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens))\n        self.attention_weights = [None] * len(self.blks)\n        for i, blk in enumerate(self.blks):\n            X = blk(X, valid_lens)\n            self.attention_weights[i] = blk.attention.attention.attention_weights\n        return X\n",[2397,9053,9054,9071,9076,9104,9123,9145,9164,9175,9180,9185,9197,9209,9214,9225,9229,9249,9266,9280,9284,9298,9303,9332,9352,9381,9400,9412,9424,9436,9448,9466,9492,9497,9502,9507,9511,9537,9542,9548,9554,9581,9609,9629,9640,9653],{"__ignoreMap":2395},[59,9055,9056,9058,9061,9063,9065,9067,9069],{"class":2401,"line":2402},[59,9057,2630],{"class":2405},[59,9059,9060],{"class":2633}," EncoderBlock",[59,9062,798],{"class":2413},[59,9064,2640],{"class":2639},[59,9066,864],{"class":2413},[59,9068,2645],{"class":2639},[59,9070,2426],{"class":2413},[59,9072,9073],{"class":2401,"line":2429},[59,9074,9075],{"class":2432},"    \"\"\"Transformer编码器块\"\"\"\n",[59,9077,9078,9080,9082,9084,9086,9088,9090,9092,9094,9096,9098,9100,9102],{"class":2401,"line":2436},[59,9079,2657],{"class":2405},[59,9081,2660],{"class":2455},[59,9083,798],{"class":2413},[59,9085,2666],{"class":2665},[59,9087,2420],{"class":2413},[59,9089,4847],{"class":2416},[59,9091,2420],{"class":2413},[59,9093,4852],{"class":2416},[59,9095,2420],{"class":2413},[59,9097,4857],{"class":2416},[59,9099,2420],{"class":2413},[59,9101,4862],{"class":2416},[59,9103,3826],{"class":2413},[59,9105,9106,9109,9111,9113,9115,9117,9119,9121],{"class":2401,"line":2443},[59,9107,9108],{"class":2416},"                 norm_shape",[59,9110,2420],{"class":2413},[59,9112,7675],{"class":2416},[59,9114,2420],{"class":2413},[59,9116,7680],{"class":2416},[59,9118,2420],{"class":2413},[59,9120,4595],{"class":2416},[59,9122,3826],{"class":2413},[59,9124,9125,9128,9130,9133,9135,9137,9139,9141,9143],{"class":2401,"line":2462},[59,9126,9127],{"class":2416},"                 dropout",[59,9129,2420],{"class":2413},[59,9131,9132],{"class":2416},"use_bias",[59,9134,815],{"class":2405},[59,9136,4883],{"class":2455},[59,9138,2420],{"class":2413},[59,9140,2676],{"class":2405},[59,9142,2679],{"class":2416},[59,9144,2426],{"class":2413},[59,9146,9147,9149,9152,9154,9156,9158,9160,9162],{"class":2401,"line":2483},[59,9148,2687],{"class":2686},[59,9150,9151],{"class":2413},"(EncoderBlock, ",[59,9153,2666],{"class":2693},[59,9155,2696],{"class":2413},[59,9157,2699],{"class":2455},[59,9159,798],{"class":2413},[59,9161,2676],{"class":2405},[59,9163,2706],{"class":2413},[59,9165,9166,9168,9170,9172],{"class":2401,"line":2491},[59,9167,2711],{"class":2693},[59,9169,4929],{"class":2413},[59,9171,815],{"class":2405},[59,9173,9174],{"class":2413}," d2l.MultiHeadAttention(\n",[59,9176,9177],{"class":2401,"line":2502},[59,9178,9179],{"class":2413},"            key_size, query_size, value_size, num_hiddens, num_heads, dropout,\n",[59,9181,9182],{"class":2401,"line":2519},[59,9183,9184],{"class":2413},"            use_bias)\n",[59,9186,9187,9189,9192,9194],{"class":2401,"line":2535},[59,9188,2711],{"class":2693},[59,9190,9191],{"class":2413},".addnorm1 ",[59,9193,815],{"class":2405},[59,9195,9196],{"class":2413}," AddNorm(norm_shape, dropout)\n",[59,9198,9199,9201,9204,9206],{"class":2401,"line":2543},[59,9200,2711],{"class":2693},[59,9202,9203],{"class":2413},".ffn ",[59,9205,815],{"class":2405},[59,9207,9208],{"class":2413}," PositionWiseFFN(\n",[59,9210,9211],{"class":2401,"line":2560},[59,9212,9213],{"class":2413},"            ffn_num_input, ffn_num_hiddens, num_hiddens)\n",[59,9215,9216,9218,9221,9223],{"class":2401,"line":2566},[59,9217,2711],{"class":2693},[59,9219,9220],{"class":2413},".addnorm2 ",[59,9222,815],{"class":2405},[59,9224,9196],{"class":2413},[59,9226,9227],{"class":2401,"line":2591},[59,9228,2725],{"emptyLinePlaceholder":2724},[59,9230,9231,9233,9235,9237,9239,9241,9243,9245,9247],{"class":2401,"line":2604},[59,9232,2657],{"class":2405},[59,9234,2752],{"class":2409},[59,9236,798],{"class":2413},[59,9238,2666],{"class":2665},[59,9240,2420],{"class":2413},[59,9242,2417],{"class":2416},[59,9244,2420],{"class":2413},[59,9246,2423],{"class":2416},[59,9248,2426],{"class":2413},[59,9250,9251,9254,9256,9258,9261,9263],{"class":2401,"line":2845},[59,9252,9253],{"class":2413},"        Y ",[59,9255,815],{"class":2405},[59,9257,5220],{"class":2693},[59,9259,9260],{"class":2413},".addnorm1(X, ",[59,9262,2666],{"class":2693},[59,9264,9265],{"class":2413},".attention(X, X, X, valid_lens))\n",[59,9267,9268,9270,9272,9275,9277],{"class":2401,"line":4781},[59,9269,2465],{"class":2405},[59,9271,5220],{"class":2693},[59,9273,9274],{"class":2413},".addnorm2(Y, ",[59,9276,2666],{"class":2693},[59,9278,9279],{"class":2413},".ffn(Y))\n",[59,9281,9282],{"class":2401,"line":5057},[59,9283,2725],{"emptyLinePlaceholder":2724},[59,9285,9286,9288,9291,9293,9296],{"class":2401,"line":5063},[59,9287,2630],{"class":2405},[59,9289,9290],{"class":2633}," TransformerEncoder",[59,9292,798],{"class":2413},[59,9294,9295],{"class":2639},"Encoder",[59,9297,2426],{"class":2413},[59,9299,9300],{"class":2401,"line":5069},[59,9301,9302],{"class":2432},"    \"\"\"Transformer编码器\"\"\"\n",[59,9304,9305,9307,9309,9311,9313,9315,9318,9320,9322,9324,9326,9328,9330],{"class":2401,"line":5075},[59,9306,2657],{"class":2405},[59,9308,2660],{"class":2455},[59,9310,798],{"class":2413},[59,9312,2666],{"class":2665},[59,9314,2420],{"class":2413},[59,9316,9317],{"class":2416},"vocab_size",[59,9319,2420],{"class":2413},[59,9321,4847],{"class":2416},[59,9323,2420],{"class":2413},[59,9325,4852],{"class":2416},[59,9327,2420],{"class":2413},[59,9329,4857],{"class":2416},[59,9331,3826],{"class":2413},[59,9333,9334,9337,9339,9342,9344,9346,9348,9350],{"class":2401,"line":5096},[59,9335,9336],{"class":2416},"                 num_hiddens",[59,9338,2420],{"class":2413},[59,9340,9341],{"class":2416},"norm_shape",[59,9343,2420],{"class":2413},[59,9345,7675],{"class":2416},[59,9347,2420],{"class":2413},[59,9349,7680],{"class":2416},[59,9351,3826],{"class":2413},[59,9353,9354,9356,9358,9361,9363,9365,9367,9369,9371,9373,9375,9377,9379],{"class":2401,"line":5115},[59,9355,4869],{"class":2416},[59,9357,2420],{"class":2413},[59,9359,9360],{"class":2416},"num_layers",[59,9362,2420],{"class":2413},[59,9364,2671],{"class":2416},[59,9366,2420],{"class":2413},[59,9368,9132],{"class":2416},[59,9370,815],{"class":2405},[59,9372,4883],{"class":2455},[59,9374,2420],{"class":2413},[59,9376,2676],{"class":2405},[59,9378,2679],{"class":2416},[59,9380,2426],{"class":2413},[59,9382,9383,9385,9388,9390,9392,9394,9396,9398],{"class":2401,"line":5134},[59,9384,2687],{"class":2686},[59,9386,9387],{"class":2413},"(TransformerEncoder, ",[59,9389,2666],{"class":2693},[59,9391,2696],{"class":2413},[59,9393,2699],{"class":2455},[59,9395,798],{"class":2413},[59,9397,2676],{"class":2405},[59,9399,2706],{"class":2413},[59,9401,9402,9404,9407,9409],{"class":2401,"line":5139},[59,9403,2711],{"class":2693},[59,9405,9406],{"class":2413},".num_hiddens ",[59,9408,815],{"class":2405},[59,9410,9411],{"class":2413}," num_hiddens\n",[59,9413,9414,9416,9419,9421],{"class":2401,"line":5155},[59,9415,2711],{"class":2693},[59,9417,9418],{"class":2413},".embedding ",[59,9420,815],{"class":2405},[59,9422,9423],{"class":2413}," nn.Embedding(vocab_size, num_hiddens)\n",[59,9425,9426,9428,9431,9433],{"class":2401,"line":5161},[59,9427,2711],{"class":2693},[59,9429,9430],{"class":2413},".pos_encoding ",[59,9432,815],{"class":2405},[59,9434,9435],{"class":2413}," PositionalEncoding(num_hiddens, dropout)\n",[59,9437,9438,9440,9443,9445],{"class":2401,"line":5167},[59,9439,2711],{"class":2693},[59,9441,9442],{"class":2413},".blks ",[59,9444,815],{"class":2405},[59,9446,9447],{"class":2413}," nn.Sequential()\n",[59,9449,9450,9453,9456,9459,9463],{"class":2401,"line":5177},[59,9451,9452],{"class":2405},"        for",[59,9454,9455],{"class":2413}," i ",[59,9457,9458],{"class":2405},"in",[59,9460,9462],{"class":9461},"sDgm9"," range",[59,9464,9465],{"class":2413},"(num_layers):\n",[59,9467,9468,9471,9474,9478,9482,9484,9486,9489],{"class":2401,"line":5201},[59,9469,9470],{"class":2693},"            self",[59,9472,9473],{"class":2413},".blks.add_module(",[59,9475,9477],{"class":9476},"sMWOi","\"",[59,9479,9481],{"class":9480},"sEzAm","block",[59,9483,9477],{"class":9476},[59,9485,5777],{"class":2405},[59,9487,9488],{"class":2686},"str",[59,9490,9491],{"class":2413},"(i),\n",[59,9493,9494],{"class":2401,"line":5206},[59,9495,9496],{"class":2413},"                EncoderBlock(key_size, query_size, value_size, num_hiddens,\n",[59,9498,9499],{"class":2401,"line":5212},[59,9500,9501],{"class":2413},"                             norm_shape, ffn_num_input, ffn_num_hiddens,\n",[59,9503,9504],{"class":2401,"line":5226},[59,9505,9506],{"class":2413},"                             num_heads, dropout, use_bias))\n",[59,9508,9509],{"class":2401,"line":5231},[59,9510,2725],{"emptyLinePlaceholder":2724},[59,9512,9513,9515,9517,9519,9521,9523,9525,9527,9529,9531,9533,9535],{"class":2401,"line":5237},[59,9514,2657],{"class":2405},[59,9516,2752],{"class":2409},[59,9518,798],{"class":2413},[59,9520,2666],{"class":2665},[59,9522,2420],{"class":2413},[59,9524,2417],{"class":2416},[59,9526,2420],{"class":2413},[59,9528,2423],{"class":2416},[59,9530,2420],{"class":2413},[59,9532,8729],{"class":2405},[59,9534,8732],{"class":2416},[59,9536,2426],{"class":2413},[59,9538,9539],{"class":2401,"line":5252},[59,9540,9541],{"class":2439},"        # 因为位置编码值在-1和1之间，\n",[59,9543,9545],{"class":2401,"line":9544},37,[59,9546,9547],{"class":2439},"        # 因此嵌入值乘以嵌入维度的平方根进行缩放，\n",[59,9549,9551],{"class":2401,"line":9550},38,[59,9552,9553],{"class":2439},"        # 然后再与位置编码相加。\n",[59,9555,9557,9559,9561,9563,9566,9568,9571,9573,9576,9578],{"class":2401,"line":9556},39,[59,9558,2569],{"class":2413},[59,9560,815],{"class":2405},[59,9562,5220],{"class":2693},[59,9564,9565],{"class":2413},".pos_encoding(",[59,9567,2666],{"class":2693},[59,9569,9570],{"class":2413},".embedding(X) ",[59,9572,8729],{"class":2405},[59,9574,9575],{"class":2413}," math.sqrt(",[59,9577,2666],{"class":2693},[59,9579,9580],{"class":2413},".num_hiddens))\n",[59,9582,9584,9586,9588,9590,9593,9595,9597,9599,9602,9604,9606],{"class":2401,"line":9583},40,[59,9585,2711],{"class":2693},[59,9587,2837],{"class":2413},[59,9589,815],{"class":2405},[59,9591,9592],{"class":2413}," [",[59,9594,2780],{"class":2455},[59,9596,7200],{"class":2413},[59,9598,8729],{"class":2405},[59,9600,9601],{"class":9461}," len",[59,9603,798],{"class":2413},[59,9605,2666],{"class":2693},[59,9607,9608],{"class":2413},".blks)\n",[59,9610,9612,9614,9617,9619,9622,9624,9626],{"class":2401,"line":9611},41,[59,9613,9452],{"class":2405},[59,9615,9616],{"class":2413}," i, blk ",[59,9618,9458],{"class":2405},[59,9620,9621],{"class":9461}," enumerate",[59,9623,798],{"class":2413},[59,9625,2666],{"class":2693},[59,9627,9628],{"class":2413},".blks):\n",[59,9630,9632,9635,9637],{"class":2401,"line":9631},42,[59,9633,9634],{"class":2413},"            X ",[59,9636,815],{"class":2405},[59,9638,9639],{"class":2413}," blk(X, valid_lens)\n",[59,9641,9643,9645,9648,9650],{"class":2401,"line":9642},43,[59,9644,9470],{"class":2693},[59,9646,9647],{"class":2413},".attention_weights[i] ",[59,9649,815],{"class":2405},[59,9651,9652],{"class":2413}," blk.attention.attention.attention_weights\n",[59,9654,9656,9658],{"class":2401,"line":9655},44,[59,9657,2465],{"class":2405},[59,9659,9660],{"class":2413}," X\n",[11,9662,9664],{"id":9663},"transformer-解码器","transformer 解码器",[15,9666,9667],{},"代码如下：",[2390,9669,9671],{"className":2392,"code":9670,"language":2394,"meta":2395,"style":2395},"class DecoderBlock(nn.Module):\n    \"\"\"解码器中第i个块\"\"\"\n    def __init__(self, key_size, query_size, value_size, num_hiddens,\n                 norm_shape, ffn_num_input, ffn_num_hiddens, num_heads,\n                 dropout, i, **kwargs):\n        super(DecoderBlock, self).__init__(**kwargs)\n        self.i = i\n        self.attention1 = d2l.MultiHeadAttention(\n            key_size, query_size, value_size, num_hiddens, num_heads, dropout)\n        self.addnorm1 = AddNorm(norm_shape, dropout)\n        self.attention2 = d2l.MultiHeadAttention(\n            key_size, query_size, value_size, num_hiddens, num_heads, dropout)\n        self.addnorm2 = AddNorm(norm_shape, dropout)\n        self.ffn = PositionWiseFFN(ffn_num_input, ffn_num_hiddens,\n                                   num_hiddens)\n        self.addnorm3 = AddNorm(norm_shape, dropout)\n\n    def forward(self, X, state):\n        enc_outputs, enc_valid_lens = state[0], state[1]\n        # 训练阶段，输出序列的所有词元都在同一时间处理，\n        # 因此state[2][self.i]初始化为None。\n        # 预测阶段，输出序列是通过词元一个接着一个解码的，\n        # 因此state[2][self.i]包含着直到当前时间步第i个块解码的输出表示\n        if state[2][self.i] is None:\n            key_values = X\n        else:\n            key_values = torch.cat((state[2][self.i], X), axis=1)\n        state[2][self.i] = key_values\n        if self.training:\n            batch_size, num_steps, _ = X.shape\n            # dec_valid_lens的开头:(batch_size,num_steps),\n            # 其中每一行是[1,2,...,num_steps]\n            dec_valid_lens = torch.arange(\n                1, num_steps + 1, device=X.device).repeat(batch_size, 1)\n        else:\n            dec_valid_lens = None\n\n        # 自注意力\n        X2 = self.attention1(X, key_values, key_values, dec_valid_lens)\n        Y = self.addnorm1(X, X2)\n        # 编码器－解码器注意力。\n        # enc_outputs的开头:(batch_size,num_steps,num_hiddens)\n        Y2 = self.attention2(Y, enc_outputs, enc_outputs, enc_valid_lens)\n        Z = self.addnorm2(Y, Y2)\n        return self.addnorm3(Z, self.ffn(Z)), state\n\nclass TransformerDecoder(AttentionDecoder):\n    def __init__(self, vocab_size, key_size, query_size, value_size,\n                 num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens,\n                 num_heads, num_layers, dropout, **kwargs):\n        super(TransformerDecoder, self).__init__(**kwargs)\n        self.num_hiddens = num_hiddens\n        self.num_layers = num_layers\n        self.embedding = nn.Embedding(vocab_size, num_hiddens)\n        self.pos_encoding = PositionalEncoding(num_hiddens, dropout)\n        self.blks = nn.Sequential()\n        for i in range(num_layers):\n            self.blks.add_module(\"block\"+str(i),\n                DecoderBlock(key_size, query_size, value_size, num_hiddens,\n                             norm_shape, ffn_num_input, ffn_num_hiddens,\n                             num_heads, dropout, i))\n        self.dense = nn.Linear(num_hiddens, vocab_size)\n\n    def init_state(self, enc_outputs, enc_valid_lens, *args):\n        return [enc_outputs, enc_valid_lens, [None] * self.num_layers]\n\n    def forward(self, X, state):\n        X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens))\n        self._attention_weights = [[None] * len(self.blks) for _ in range (2)]\n        for i, blk in enumerate(self.blks):\n            X, state = blk(X, state)\n            # 解码器自注意力权重\n            self._attention_weights[0][i] = blk.attention1.attention.attention_weights\n            # “编码器－解码器”自注意力权重\n            self._attention_weights[1][i] = blk.attention2.attention.attention_weights\n        return self.dense(X), state\n\n    @property\n    def attention_weights(self):\n        return self._attention_weights\n",[2397,9672,9673,9690,9695,9723,9741,9757,9776,9788,9799,9804,9814,9825,9829,9839,9850,9855,9866,9870,9890,9909,9914,9919,9924,9929,9951,9960,9966,9993,10011,10020,10029,10034,10039,10049,10075,10081,10090,10094,10099,10111,10122,10127,10132,10144,10156,10171,10176,10191,10220,10239,10260,10280,10291,10304,10315,10326,10337,10350,10369,10375,10380,10386,10399,10404,10432,10451,10456,10477,10500,10546,10563,10574,10580,10598,10604,10620,10630,10635,10644,10658],{"__ignoreMap":2395},[59,9674,9675,9677,9680,9682,9684,9686,9688],{"class":2401,"line":2402},[59,9676,2630],{"class":2405},[59,9678,9679],{"class":2633}," DecoderBlock",[59,9681,798],{"class":2413},[59,9683,2640],{"class":2639},[59,9685,864],{"class":2413},[59,9687,2645],{"class":2639},[59,9689,2426],{"class":2413},[59,9691,9692],{"class":2401,"line":2429},[59,9693,9694],{"class":2432},"    \"\"\"解码器中第i个块\"\"\"\n",[59,9696,9697,9699,9701,9703,9705,9707,9709,9711,9713,9715,9717,9719,9721],{"class":2401,"line":2436},[59,9698,2657],{"class":2405},[59,9700,2660],{"class":2455},[59,9702,798],{"class":2413},[59,9704,2666],{"class":2665},[59,9706,2420],{"class":2413},[59,9708,4847],{"class":2416},[59,9710,2420],{"class":2413},[59,9712,4852],{"class":2416},[59,9714,2420],{"class":2413},[59,9716,4857],{"class":2416},[59,9718,2420],{"class":2413},[59,9720,4862],{"class":2416},[59,9722,3826],{"class":2413},[59,9724,9725,9727,9729,9731,9733,9735,9737,9739],{"class":2401,"line":2443},[59,9726,9108],{"class":2416},[59,9728,2420],{"class":2413},[59,9730,7675],{"class":2416},[59,9732,2420],{"class":2413},[59,9734,7680],{"class":2416},[59,9736,2420],{"class":2413},[59,9738,4595],{"class":2416},[59,9740,3826],{"class":2413},[59,9742,9743,9745,9747,9749,9751,9753,9755],{"class":2401,"line":2462},[59,9744,9127],{"class":2416},[59,9746,2420],{"class":2413},[59,9748,789],{"class":2416},[59,9750,2420],{"class":2413},[59,9752,2676],{"class":2405},[59,9754,2679],{"class":2416},[59,9756,2426],{"class":2413},[59,9758,9759,9761,9764,9766,9768,9770,9772,9774],{"class":2401,"line":2483},[59,9760,2687],{"class":2686},[59,9762,9763],{"class":2413},"(DecoderBlock, ",[59,9765,2666],{"class":2693},[59,9767,2696],{"class":2413},[59,9769,2699],{"class":2455},[59,9771,798],{"class":2413},[59,9773,2676],{"class":2405},[59,9775,2706],{"class":2413},[59,9777,9778,9780,9783,9785],{"class":2401,"line":2491},[59,9779,2711],{"class":2693},[59,9781,9782],{"class":2413},".i ",[59,9784,815],{"class":2405},[59,9786,9787],{"class":2413}," i\n",[59,9789,9790,9792,9795,9797],{"class":2401,"line":2502},[59,9791,2711],{"class":2693},[59,9793,9794],{"class":2413},".attention1 ",[59,9796,815],{"class":2405},[59,9798,9174],{"class":2413},[59,9800,9801],{"class":2401,"line":2519},[59,9802,9803],{"class":2413},"            key_size, query_size, value_size, num_hiddens, num_heads, dropout)\n",[59,9805,9806,9808,9810,9812],{"class":2401,"line":2535},[59,9807,2711],{"class":2693},[59,9809,9191],{"class":2413},[59,9811,815],{"class":2405},[59,9813,9196],{"class":2413},[59,9815,9816,9818,9821,9823],{"class":2401,"line":2543},[59,9817,2711],{"class":2693},[59,9819,9820],{"class":2413},".attention2 ",[59,9822,815],{"class":2405},[59,9824,9174],{"class":2413},[59,9826,9827],{"class":2401,"line":2560},[59,9828,9803],{"class":2413},[59,9830,9831,9833,9835,9837],{"class":2401,"line":2566},[59,9832,2711],{"class":2693},[59,9834,9220],{"class":2413},[59,9836,815],{"class":2405},[59,9838,9196],{"class":2413},[59,9840,9841,9843,9845,9847],{"class":2401,"line":2591},[59,9842,2711],{"class":2693},[59,9844,9203],{"class":2413},[59,9846,815],{"class":2405},[59,9848,9849],{"class":2413}," PositionWiseFFN(ffn_num_input, ffn_num_hiddens,\n",[59,9851,9852],{"class":2401,"line":2604},[59,9853,9854],{"class":2413},"                                   num_hiddens)\n",[59,9856,9857,9859,9862,9864],{"class":2401,"line":2845},[59,9858,2711],{"class":2693},[59,9860,9861],{"class":2413},".addnorm3 ",[59,9863,815],{"class":2405},[59,9865,9196],{"class":2413},[59,9867,9868],{"class":2401,"line":4781},[59,9869,2725],{"emptyLinePlaceholder":2724},[59,9871,9872,9874,9876,9878,9880,9882,9884,9886,9888],{"class":2401,"line":5057},[59,9873,2657],{"class":2405},[59,9875,2752],{"class":2409},[59,9877,798],{"class":2413},[59,9879,2666],{"class":2665},[59,9881,2420],{"class":2413},[59,9883,2417],{"class":2416},[59,9885,2420],{"class":2413},[59,9887,8862],{"class":2416},[59,9889,2426],{"class":2413},[59,9891,9892,9895,9897,9900,9902,9905,9907],{"class":2401,"line":5063},[59,9893,9894],{"class":2413},"        enc_outputs, enc_valid_lens ",[59,9896,815],{"class":2405},[59,9898,9899],{"class":2413}," state[",[59,9901,1754],{"class":2455},[59,9903,9904],{"class":2413},"], state[",[59,9906,83],{"class":2455},[59,9908,2799],{"class":2413},[59,9910,9911],{"class":2401,"line":5069},[59,9912,9913],{"class":2439},"        # 训练阶段，输出序列的所有词元都在同一时间处理，\n",[59,9915,9916],{"class":2401,"line":5075},[59,9917,9918],{"class":2439},"        # 因此state[2][self.i]初始化为None。\n",[59,9920,9921],{"class":2401,"line":5096},[59,9922,9923],{"class":2439},"        # 预测阶段，输出序列是通过词元一个接着一个解码的，\n",[59,9925,9926],{"class":2401,"line":5115},[59,9927,9928],{"class":2439},"        # 因此state[2][self.i]包含着直到当前时间步第i个块解码的输出表示\n",[59,9930,9931,9933,9935,9937,9940,9942,9945,9947,9949],{"class":2401,"line":5134},[59,9932,2505],{"class":2405},[59,9934,9899],{"class":2413},[59,9936,2821],{"class":2455},[59,9938,9939],{"class":2413},"][",[59,9941,2666],{"class":2693},[59,9943,9944],{"class":2413},".i] ",[59,9946,2452],{"class":2405},[59,9948,2456],{"class":2455},[59,9950,2459],{"class":2413},[59,9952,9953,9956,9958],{"class":2401,"line":5139},[59,9954,9955],{"class":2413},"            key_values ",[59,9957,815],{"class":2405},[59,9959,9660],{"class":2413},[59,9961,9962,9964],{"class":2401,"line":5155},[59,9963,2538],{"class":2405},[59,9965,2459],{"class":2413},[59,9967,9968,9970,9972,9975,9977,9979,9981,9984,9987,9989,9991],{"class":2401,"line":5161},[59,9969,9955],{"class":2413},[59,9971,815],{"class":2405},[59,9973,9974],{"class":2413}," torch.cat((state[",[59,9976,2821],{"class":2455},[59,9978,9939],{"class":2413},[59,9980,2666],{"class":2693},[59,9982,9983],{"class":2413},".i], X), ",[59,9985,9986],{"class":2471},"axis",[59,9988,815],{"class":2405},[59,9990,83],{"class":2455},[59,9992,2480],{"class":2413},[59,9994,9995,9998,10000,10002,10004,10006,10008],{"class":2401,"line":5167},[59,9996,9997],{"class":2413},"        state[",[59,9999,2821],{"class":2455},[59,10001,9939],{"class":2413},[59,10003,2666],{"class":2693},[59,10005,9944],{"class":2413},[59,10007,815],{"class":2405},[59,10009,10010],{"class":2413}," key_values\n",[59,10012,10013,10015,10017],{"class":2401,"line":5177},[59,10014,2505],{"class":2405},[59,10016,5220],{"class":2693},[59,10018,10019],{"class":2413},".training:\n",[59,10021,10022,10025,10027],{"class":2401,"line":5201},[59,10023,10024],{"class":2413},"            batch_size, num_steps, _ ",[59,10026,815],{"class":2405},[59,10028,2499],{"class":2413},[59,10030,10031],{"class":2401,"line":5206},[59,10032,10033],{"class":2439},"            # dec_valid_lens的开头:(batch_size,num_steps),\n",[59,10035,10036],{"class":2401,"line":5212},[59,10037,10038],{"class":2439},"            # 其中每一行是[1,2,...,num_steps]\n",[59,10040,10041,10044,10046],{"class":2401,"line":5226},[59,10042,10043],{"class":2413},"            dec_valid_lens ",[59,10045,815],{"class":2405},[59,10047,10048],{"class":2413}," torch.arange(\n",[59,10050,10051,10054,10057,10059,10061,10063,10066,10068,10071,10073],{"class":2401,"line":5231},[59,10052,10053],{"class":2455},"                1",[59,10055,10056],{"class":2413},", num_steps ",[59,10058,5777],{"class":2405},[59,10060,2514],{"class":2455},[59,10062,2420],{"class":2413},[59,10064,10065],{"class":2471},"device",[59,10067,815],{"class":2405},[59,10069,10070],{"class":2413},"X.device).repeat(batch_size, ",[59,10072,83],{"class":2455},[59,10074,2480],{"class":2413},[59,10076,10077,10079],{"class":2401,"line":5237},[59,10078,2538],{"class":2405},[59,10080,2459],{"class":2413},[59,10082,10083,10085,10087],{"class":2401,"line":5252},[59,10084,10043],{"class":2413},[59,10086,815],{"class":2405},[59,10088,10089],{"class":2455}," None\n",[59,10091,10092],{"class":2401,"line":9544},[59,10093,2725],{"emptyLinePlaceholder":2724},[59,10095,10096],{"class":2401,"line":9550},[59,10097,10098],{"class":2439},"        # 自注意力\n",[59,10100,10101,10104,10106,10108],{"class":2401,"line":9556},[59,10102,10103],{"class":2413},"        X2 ",[59,10105,815],{"class":2405},[59,10107,5220],{"class":2693},[59,10109,10110],{"class":2413},".attention1(X, key_values, key_values, dec_valid_lens)\n",[59,10112,10113,10115,10117,10119],{"class":2401,"line":9583},[59,10114,9253],{"class":2413},[59,10116,815],{"class":2405},[59,10118,5220],{"class":2693},[59,10120,10121],{"class":2413},".addnorm1(X, X2)\n",[59,10123,10124],{"class":2401,"line":9611},[59,10125,10126],{"class":2439},"        # 编码器－解码器注意力。\n",[59,10128,10129],{"class":2401,"line":9631},[59,10130,10131],{"class":2439},"        # enc_outputs的开头:(batch_size,num_steps,num_hiddens)\n",[59,10133,10134,10137,10139,10141],{"class":2401,"line":9642},[59,10135,10136],{"class":2413},"        Y2 ",[59,10138,815],{"class":2405},[59,10140,5220],{"class":2693},[59,10142,10143],{"class":2413},".attention2(Y, enc_outputs, enc_outputs, enc_valid_lens)\n",[59,10145,10146,10149,10151,10153],{"class":2401,"line":9655},[59,10147,10148],{"class":2413},"        Z ",[59,10150,815],{"class":2405},[59,10152,5220],{"class":2693},[59,10154,10155],{"class":2413},".addnorm2(Y, Y2)\n",[59,10157,10159,10161,10163,10166,10168],{"class":2401,"line":10158},45,[59,10160,2465],{"class":2405},[59,10162,5220],{"class":2693},[59,10164,10165],{"class":2413},".addnorm3(Z, ",[59,10167,2666],{"class":2693},[59,10169,10170],{"class":2413},".ffn(Z)), state\n",[59,10172,10174],{"class":2401,"line":10173},46,[59,10175,2725],{"emptyLinePlaceholder":2724},[59,10177,10179,10181,10184,10186,10189],{"class":2401,"line":10178},47,[59,10180,2630],{"class":2405},[59,10182,10183],{"class":2633}," TransformerDecoder",[59,10185,798],{"class":2413},[59,10187,10188],{"class":2639},"AttentionDecoder",[59,10190,2426],{"class":2413},[59,10192,10194,10196,10198,10200,10202,10204,10206,10208,10210,10212,10214,10216,10218],{"class":2401,"line":10193},48,[59,10195,2657],{"class":2405},[59,10197,2660],{"class":2455},[59,10199,798],{"class":2413},[59,10201,2666],{"class":2665},[59,10203,2420],{"class":2413},[59,10205,9317],{"class":2416},[59,10207,2420],{"class":2413},[59,10209,4847],{"class":2416},[59,10211,2420],{"class":2413},[59,10213,4852],{"class":2416},[59,10215,2420],{"class":2413},[59,10217,4857],{"class":2416},[59,10219,3826],{"class":2413},[59,10221,10223,10225,10227,10229,10231,10233,10235,10237],{"class":2401,"line":10222},49,[59,10224,9336],{"class":2416},[59,10226,2420],{"class":2413},[59,10228,9341],{"class":2416},[59,10230,2420],{"class":2413},[59,10232,7675],{"class":2416},[59,10234,2420],{"class":2413},[59,10236,7680],{"class":2416},[59,10238,3826],{"class":2413},[59,10240,10242,10244,10246,10248,10250,10252,10254,10256,10258],{"class":2401,"line":10241},50,[59,10243,4869],{"class":2416},[59,10245,2420],{"class":2413},[59,10247,9360],{"class":2416},[59,10249,2420],{"class":2413},[59,10251,2671],{"class":2416},[59,10253,2420],{"class":2413},[59,10255,2676],{"class":2405},[59,10257,2679],{"class":2416},[59,10259,2426],{"class":2413},[59,10261,10263,10265,10268,10270,10272,10274,10276,10278],{"class":2401,"line":10262},51,[59,10264,2687],{"class":2686},[59,10266,10267],{"class":2413},"(TransformerDecoder, ",[59,10269,2666],{"class":2693},[59,10271,2696],{"class":2413},[59,10273,2699],{"class":2455},[59,10275,798],{"class":2413},[59,10277,2676],{"class":2405},[59,10279,2706],{"class":2413},[59,10281,10283,10285,10287,10289],{"class":2401,"line":10282},52,[59,10284,2711],{"class":2693},[59,10286,9406],{"class":2413},[59,10288,815],{"class":2405},[59,10290,9411],{"class":2413},[59,10292,10294,10296,10299,10301],{"class":2401,"line":10293},53,[59,10295,2711],{"class":2693},[59,10297,10298],{"class":2413},".num_layers ",[59,10300,815],{"class":2405},[59,10302,10303],{"class":2413}," num_layers\n",[59,10305,10307,10309,10311,10313],{"class":2401,"line":10306},54,[59,10308,2711],{"class":2693},[59,10310,9418],{"class":2413},[59,10312,815],{"class":2405},[59,10314,9423],{"class":2413},[59,10316,10318,10320,10322,10324],{"class":2401,"line":10317},55,[59,10319,2711],{"class":2693},[59,10321,9430],{"class":2413},[59,10323,815],{"class":2405},[59,10325,9435],{"class":2413},[59,10327,10329,10331,10333,10335],{"class":2401,"line":10328},56,[59,10330,2711],{"class":2693},[59,10332,9442],{"class":2413},[59,10334,815],{"class":2405},[59,10336,9447],{"class":2413},[59,10338,10340,10342,10344,10346,10348],{"class":2401,"line":10339},57,[59,10341,9452],{"class":2405},[59,10343,9455],{"class":2413},[59,10345,9458],{"class":2405},[59,10347,9462],{"class":9461},[59,10349,9465],{"class":2413},[59,10351,10353,10355,10357,10359,10361,10363,10365,10367],{"class":2401,"line":10352},58,[59,10354,9470],{"class":2693},[59,10356,9473],{"class":2413},[59,10358,9477],{"class":9476},[59,10360,9481],{"class":9480},[59,10362,9477],{"class":9476},[59,10364,5777],{"class":2405},[59,10366,9488],{"class":2686},[59,10368,9491],{"class":2413},[59,10370,10372],{"class":2401,"line":10371},59,[59,10373,10374],{"class":2413},"                DecoderBlock(key_size, query_size, value_size, num_hiddens,\n",[59,10376,10378],{"class":2401,"line":10377},60,[59,10379,9501],{"class":2413},[59,10381,10383],{"class":2401,"line":10382},61,[59,10384,10385],{"class":2413},"                             num_heads, dropout, i))\n",[59,10387,10389,10391,10394,10396],{"class":2401,"line":10388},62,[59,10390,2711],{"class":2693},[59,10392,10393],{"class":2413},".dense ",[59,10395,815],{"class":2405},[59,10397,10398],{"class":2413}," nn.Linear(num_hiddens, vocab_size)\n",[59,10400,10402],{"class":2401,"line":10401},63,[59,10403,2725],{"emptyLinePlaceholder":2724},[59,10405,10407,10409,10411,10413,10415,10417,10419,10421,10424,10426,10428,10430],{"class":2401,"line":10406},64,[59,10408,2657],{"class":2405},[59,10410,8816],{"class":2409},[59,10412,798],{"class":2413},[59,10414,2666],{"class":2665},[59,10416,2420],{"class":2413},[59,10418,8825],{"class":2416},[59,10420,2420],{"class":2413},[59,10422,10423],{"class":2416},"enc_valid_lens",[59,10425,2420],{"class":2413},[59,10427,8729],{"class":2405},[59,10429,8732],{"class":2416},[59,10431,2426],{"class":2413},[59,10433,10435,10437,10440,10442,10444,10446,10448],{"class":2401,"line":10434},65,[59,10436,2465],{"class":2405},[59,10438,10439],{"class":2413}," [enc_outputs, enc_valid_lens, [",[59,10441,2780],{"class":2455},[59,10443,7200],{"class":2413},[59,10445,8729],{"class":2405},[59,10447,5220],{"class":2693},[59,10449,10450],{"class":2413},".num_layers]\n",[59,10452,10454],{"class":2401,"line":10453},66,[59,10455,2725],{"emptyLinePlaceholder":2724},[59,10457,10459,10461,10463,10465,10467,10469,10471,10473,10475],{"class":2401,"line":10458},67,[59,10460,2657],{"class":2405},[59,10462,2752],{"class":2409},[59,10464,798],{"class":2413},[59,10466,2666],{"class":2665},[59,10468,2420],{"class":2413},[59,10470,2417],{"class":2416},[59,10472,2420],{"class":2413},[59,10474,8862],{"class":2416},[59,10476,2426],{"class":2413},[59,10478,10480,10482,10484,10486,10488,10490,10492,10494,10496,10498],{"class":2401,"line":10479},68,[59,10481,2569],{"class":2413},[59,10483,815],{"class":2405},[59,10485,5220],{"class":2693},[59,10487,9565],{"class":2413},[59,10489,2666],{"class":2693},[59,10491,9570],{"class":2413},[59,10493,8729],{"class":2405},[59,10495,9575],{"class":2413},[59,10497,2666],{"class":2693},[59,10499,9580],{"class":2413},[59,10501,10503,10505,10508,10510,10513,10515,10517,10519,10521,10523,10525,10528,10531,10534,10536,10538,10541,10543],{"class":2401,"line":10502},69,[59,10504,2711],{"class":2693},[59,10506,10507],{"class":2413},"._attention_weights ",[59,10509,815],{"class":2405},[59,10511,10512],{"class":2413}," [[",[59,10514,2780],{"class":2455},[59,10516,7200],{"class":2413},[59,10518,8729],{"class":2405},[59,10520,9601],{"class":9461},[59,10522,798],{"class":2413},[59,10524,2666],{"class":2693},[59,10526,10527],{"class":2413},".blks) ",[59,10529,10530],{"class":2405},"for",[59,10532,10533],{"class":2413}," _ ",[59,10535,9458],{"class":2405},[59,10537,9462],{"class":9461},[59,10539,10540],{"class":2413}," (",[59,10542,2821],{"class":2455},[59,10544,10545],{"class":2413},")]\n",[59,10547,10549,10551,10553,10555,10557,10559,10561],{"class":2401,"line":10548},70,[59,10550,9452],{"class":2405},[59,10552,9616],{"class":2413},[59,10554,9458],{"class":2405},[59,10556,9621],{"class":9461},[59,10558,798],{"class":2413},[59,10560,2666],{"class":2693},[59,10562,9628],{"class":2413},[59,10564,10566,10569,10571],{"class":2401,"line":10565},71,[59,10567,10568],{"class":2413},"            X, state ",[59,10570,815],{"class":2405},[59,10572,10573],{"class":2413}," blk(X, state)\n",[59,10575,10577],{"class":2401,"line":10576},72,[59,10578,10579],{"class":2439},"            # 解码器自注意力权重\n",[59,10581,10583,10585,10588,10590,10593,10595],{"class":2401,"line":10582},73,[59,10584,9470],{"class":2693},[59,10586,10587],{"class":2413},"._attention_weights[",[59,10589,1754],{"class":2455},[59,10591,10592],{"class":2413},"][i] ",[59,10594,815],{"class":2405},[59,10596,10597],{"class":2413}," blk.attention1.attention.attention_weights\n",[59,10599,10601],{"class":2401,"line":10600},74,[59,10602,10603],{"class":2439},"            # “编码器－解码器”自注意力权重\n",[59,10605,10607,10609,10611,10613,10615,10617],{"class":2401,"line":10606},75,[59,10608,9470],{"class":2693},[59,10610,10587],{"class":2413},[59,10612,83],{"class":2455},[59,10614,10592],{"class":2413},[59,10616,815],{"class":2405},[59,10618,10619],{"class":2413}," blk.attention2.attention.attention_weights\n",[59,10621,10623,10625,10627],{"class":2401,"line":10622},76,[59,10624,2465],{"class":2405},[59,10626,5220],{"class":2693},[59,10628,10629],{"class":2413},".dense(X), state\n",[59,10631,10633],{"class":2401,"line":10632},77,[59,10634,2725],{"emptyLinePlaceholder":2724},[59,10636,10638,10641],{"class":2401,"line":10637},78,[59,10639,10640],{"class":2409},"    @",[59,10642,10643],{"class":2686},"property\n",[59,10645,10647,10649,10652,10654,10656],{"class":2401,"line":10646},79,[59,10648,2657],{"class":2405},[59,10650,10651],{"class":2409}," attention_weights",[59,10653,798],{"class":2413},[59,10655,2666],{"class":2665},[59,10657,2426],{"class":2413},[59,10659,10661,10663,10665],{"class":2401,"line":10660},80,[59,10662,2465],{"class":2405},[59,10664,5220],{"class":2693},[59,10666,10667],{"class":2413},"._attention_weights\n",[11,10669,10670],{"id":10670},"训练",[15,10672,10673],{},"同样，训练时需要屏蔽部分不需要的内容：",[2390,10675,10677],{"className":2392,"code":10676,"language":2394,"meta":2395,"style":2395},"def sequence_mask(X, valid_len, value=0):\n    \"\"\"在序列中屏蔽不相关的项\"\"\"\n    maxlen = X.size(1)\n    mask = torch.arange((maxlen), dtype=torch.float32,\n                        device=X.device)[None, :] \u003C valid_len[:, None]\n    X[~mask] = value\n    return X\n\nclass MaskedSoftmaxCELoss(nn.CrossEntropyLoss):\n    \"\"\"带遮蔽的softmax交叉熵损失函数\"\"\"\n    # pred的形状：(batch_size,num_steps,vocab_size)\n    # label的形状：(batch_size,num_steps)\n    # valid_len的形状：(batch_size,)\n    def forward(self, pred, label, valid_len):\n        weights = torch.ones_like(label)\n        weights = sequence_mask(weights, valid_len)\n        self.reduction='none'\n        unweighted_loss = super(MaskedSoftmaxCELoss, self).forward(\n            pred.permute(0, 2, 1), label)\n        weighted_loss = (unweighted_loss * weights).mean(dim=1)\n        return weighted_loss\n",[2397,10678,10679,10706,10711,10725,10743,10772,10788,10794,10798,10816,10821,10826,10831,10836,10862,10872,10881,10899,10917,10935,10958],{"__ignoreMap":2395},[59,10680,10681,10683,10686,10688,10690,10692,10695,10697,10700,10702,10704],{"class":2401,"line":2402},[59,10682,2406],{"class":2405},[59,10684,10685],{"class":2409}," sequence_mask",[59,10687,798],{"class":2413},[59,10689,2417],{"class":2416},[59,10691,2420],{"class":2413},[59,10693,10694],{"class":2416},"valid_len",[59,10696,2420],{"class":2413},[59,10698,10699],{"class":2416},"value",[59,10701,815],{"class":2405},[59,10703,1754],{"class":2455},[59,10705,2426],{"class":2413},[59,10707,10708],{"class":2401,"line":2429},[59,10709,10710],{"class":2432},"    \"\"\"在序列中屏蔽不相关的项\"\"\"\n",[59,10712,10713,10716,10718,10721,10723],{"class":2401,"line":2436},[59,10714,10715],{"class":2413},"    maxlen ",[59,10717,815],{"class":2405},[59,10719,10720],{"class":2413}," X.size(",[59,10722,83],{"class":2455},[59,10724,2480],{"class":2413},[59,10726,10727,10730,10732,10735,10738,10740],{"class":2401,"line":2443},[59,10728,10729],{"class":2413},"    mask ",[59,10731,815],{"class":2405},[59,10733,10734],{"class":2413}," torch.arange((maxlen), ",[59,10736,10737],{"class":2471},"dtype",[59,10739,815],{"class":2405},[59,10741,10742],{"class":2413},"torch.float32,\n",[59,10744,10745,10748,10750,10753,10755,10758,10761,10764,10766,10768,10770],{"class":2401,"line":2462},[59,10746,10747],{"class":2471},"                        device",[59,10749,815],{"class":2405},[59,10751,10752],{"class":2413},"X.device)[",[59,10754,2780],{"class":2455},[59,10756,10757],{"class":2413},", :] ",[59,10759,10760],{"class":2405},"\u003C",[59,10762,10763],{"class":2413}," valid_len[",[59,10765,7184],{"class":7183},[59,10767,2420],{"class":2413},[59,10769,2780],{"class":2455},[59,10771,2799],{"class":2413},[59,10773,10774,10777,10780,10783,10785],{"class":2401,"line":2483},[59,10775,10776],{"class":2413},"    X[",[59,10778,10779],{"class":2405},"~",[59,10781,10782],{"class":2413},"mask] ",[59,10784,815],{"class":2405},[59,10786,10787],{"class":2413}," value\n",[59,10789,10790,10792],{"class":2401,"line":2491},[59,10791,4687],{"class":2405},[59,10793,9660],{"class":2413},[59,10795,10796],{"class":2401,"line":2502},[59,10797,2725],{"emptyLinePlaceholder":2724},[59,10799,10800,10802,10805,10807,10809,10811,10814],{"class":2401,"line":2519},[59,10801,2630],{"class":2405},[59,10803,10804],{"class":2633}," MaskedSoftmaxCELoss",[59,10806,798],{"class":2413},[59,10808,2640],{"class":2639},[59,10810,864],{"class":2413},[59,10812,10813],{"class":2639},"CrossEntropyLoss",[59,10815,2426],{"class":2413},[59,10817,10818],{"class":2401,"line":2535},[59,10819,10820],{"class":2432},"    \"\"\"带遮蔽的softmax交叉熵损失函数\"\"\"\n",[59,10822,10823],{"class":2401,"line":2543},[59,10824,10825],{"class":2439},"    # pred的形状：(batch_size,num_steps,vocab_size)\n",[59,10827,10828],{"class":2401,"line":2560},[59,10829,10830],{"class":2439},"    # label的形状：(batch_size,num_steps)\n",[59,10832,10833],{"class":2401,"line":2566},[59,10834,10835],{"class":2439},"    # valid_len的形状：(batch_size,)\n",[59,10837,10838,10840,10842,10844,10846,10848,10851,10853,10856,10858,10860],{"class":2401,"line":2591},[59,10839,2657],{"class":2405},[59,10841,2752],{"class":2409},[59,10843,798],{"class":2413},[59,10845,2666],{"class":2665},[59,10847,2420],{"class":2413},[59,10849,10850],{"class":2416},"pred",[59,10852,2420],{"class":2413},[59,10854,10855],{"class":2416},"label",[59,10857,2420],{"class":2413},[59,10859,10694],{"class":2416},[59,10861,2426],{"class":2413},[59,10863,10864,10867,10869],{"class":2401,"line":2604},[59,10865,10866],{"class":2413},"        weights ",[59,10868,815],{"class":2405},[59,10870,10871],{"class":2413}," torch.ones_like(label)\n",[59,10873,10874,10876,10878],{"class":2401,"line":2845},[59,10875,10866],{"class":2413},[59,10877,815],{"class":2405},[59,10879,10880],{"class":2413}," sequence_mask(weights, valid_len)\n",[59,10882,10883,10885,10888,10890,10893,10896],{"class":2401,"line":4781},[59,10884,2711],{"class":2693},[59,10886,10887],{"class":2413},".reduction",[59,10889,815],{"class":2405},[59,10891,10892],{"class":9476},"'",[59,10894,10895],{"class":9480},"none",[59,10897,10898],{"class":9476},"'\n",[59,10900,10901,10904,10906,10909,10912,10914],{"class":2401,"line":5057},[59,10902,10903],{"class":2413},"        unweighted_loss ",[59,10905,815],{"class":2405},[59,10907,10908],{"class":2686}," super",[59,10910,10911],{"class":2413},"(MaskedSoftmaxCELoss, ",[59,10913,2666],{"class":2693},[59,10915,10916],{"class":2413},").forward(\n",[59,10918,10919,10922,10924,10926,10928,10930,10932],{"class":2401,"line":5063},[59,10920,10921],{"class":2413},"            pred.permute(",[59,10923,1754],{"class":2455},[59,10925,2420],{"class":2413},[59,10927,2821],{"class":2455},[59,10929,2420],{"class":2413},[59,10931,83],{"class":2455},[59,10933,10934],{"class":2413},"), label)\n",[59,10936,10937,10940,10942,10945,10947,10950,10952,10954,10956],{"class":2401,"line":5069},[59,10938,10939],{"class":2413},"        weighted_loss ",[59,10941,815],{"class":2405},[59,10943,10944],{"class":2413}," (unweighted_loss ",[59,10946,8729],{"class":2405},[59,10948,10949],{"class":2413}," weights).mean(",[59,10951,2472],{"class":2471},[59,10953,815],{"class":2405},[59,10955,83],{"class":2455},[59,10957,2480],{"class":2413},[59,10959,10960,10962],{"class":2401,"line":5075},[59,10961,2465],{"class":2405},[59,10963,10964],{"class":2413}," weighted_loss\n",[15,10966,10967],{},"梯度裁剪函数：",[2390,10969,10971],{"className":2392,"code":10970,"language":2394,"meta":2395,"style":2395},"def grad_clipping(net, theta):\n    \"\"\"裁剪梯度\"\"\"\n    if isinstance(net, nn.Module):\n        params = [p for p in net.parameters() if p.requires_grad]\n    else:\n        params = net.params\n    norm = torch.sqrt(sum(torch.sum((p.grad ** 2)) for p in params))\n    if norm > theta:\n        for param in params:\n            param.grad[:] *= theta / norm\n",[2397,10972,10973,10992,10997,11007,11033,11039,11048,11080,11093,11105],{"__ignoreMap":2395},[59,10974,10975,10977,10980,10982,10985,10987,10990],{"class":2401,"line":2402},[59,10976,2406],{"class":2405},[59,10978,10979],{"class":2409}," grad_clipping",[59,10981,798],{"class":2413},[59,10983,10984],{"class":2416},"net",[59,10986,2420],{"class":2413},[59,10988,10989],{"class":2416},"theta",[59,10991,2426],{"class":2413},[59,10993,10994],{"class":2401,"line":2429},[59,10995,10996],{"class":2432},"    \"\"\"裁剪梯度\"\"\"\n",[59,10998,10999,11001,11004],{"class":2401,"line":2436},[59,11000,2446],{"class":2405},[59,11002,11003],{"class":9461}," isinstance",[59,11005,11006],{"class":2413},"(net, nn.Module):\n",[59,11008,11009,11012,11014,11017,11019,11022,11024,11027,11030],{"class":2401,"line":2443},[59,11010,11011],{"class":2413},"        params ",[59,11013,815],{"class":2405},[59,11015,11016],{"class":2413}," [p ",[59,11018,10530],{"class":2405},[59,11020,11021],{"class":2413}," p ",[59,11023,9458],{"class":2405},[59,11025,11026],{"class":2413}," net.parameters() ",[59,11028,11029],{"class":2405},"if",[59,11031,11032],{"class":2413}," p.requires_grad]\n",[59,11034,11035,11037],{"class":2401,"line":2462},[59,11036,2486],{"class":2405},[59,11038,2459],{"class":2413},[59,11040,11041,11043,11045],{"class":2401,"line":2483},[59,11042,11011],{"class":2413},[59,11044,815],{"class":2405},[59,11046,11047],{"class":2413}," net.params\n",[59,11049,11050,11053,11055,11058,11061,11064,11066,11069,11071,11073,11075,11077],{"class":2401,"line":2491},[59,11051,11052],{"class":2413},"    norm ",[59,11054,815],{"class":2405},[59,11056,11057],{"class":2413}," torch.sqrt(",[59,11059,11060],{"class":9461},"sum",[59,11062,11063],{"class":2413},"(torch.sum((p.grad ",[59,11065,2676],{"class":2405},[59,11067,11068],{"class":2455}," 2",[59,11070,2824],{"class":2413},[59,11072,10530],{"class":2405},[59,11074,11021],{"class":2413},[59,11076,9458],{"class":2405},[59,11078,11079],{"class":2413}," params))\n",[59,11081,11082,11084,11087,11090],{"class":2401,"line":2502},[59,11083,2446],{"class":2405},[59,11085,11086],{"class":2413}," norm ",[59,11088,11089],{"class":2405},">",[59,11091,11092],{"class":2413}," theta:\n",[59,11094,11095,11097,11100,11102],{"class":2401,"line":2519},[59,11096,9452],{"class":2405},[59,11098,11099],{"class":2413}," param ",[59,11101,9458],{"class":2405},[59,11103,11104],{"class":2413}," params:\n",[59,11106,11107,11110,11112,11114,11117,11120,11122],{"class":2401,"line":2535},[59,11108,11109],{"class":2413},"            param.grad[",[59,11111,7184],{"class":7183},[59,11113,7200],{"class":2413},[59,11115,11116],{"class":2405},"*=",[59,11118,11119],{"class":2413}," theta ",[59,11121,2827],{"class":2405},[59,11123,11124],{"class":2413}," norm\n",[15,11126,11127,11128],{},"之后，在训练时，特定的序列开始词元（“",[11129,11130,11131,11132],"bos",{},"”）和 原始的输出序列（不包括序列结束词元“",[11133,11134,11135],"eos",{},"”） 拼接在一起作为解码器的输入。 这被称为强制教学（teacher forcing）， 因为原始的输出序列（词元的标签）被送入解码器。 或者，将来自上一个时间步的预测得到的词元作为解码器的当前输入。",[2390,11137,11139],{"className":2392,"code":11138,"language":2394,"meta":2395,"style":2395},"def train_seq2seq(net, data_iter, lr, num_epochs, tgt_vocab, device):\n    \"\"\"训练序列到序列模型\"\"\"\n    def xavier_init_weights(m):\n        if type(m) == nn.Linear:\n            nn.init.xavier_uniform_(m.weight)\n        if type(m) == nn.GRU:\n            for param in m._flat_weights_names:\n                if \"weight\" in param:\n                    nn.init.xavier_uniform_(m._parameters[param])\n\n    net.apply(xavier_init_weights)\n    net.to(device)\n    optimizer = torch.optim.Adam(net.parameters(), lr=lr)\n    loss = MaskedSoftmaxCELoss()\n    net.train()\n    animator = d2l.Animator(xlabel='epoch', ylabel='loss',\n                     xlim=[10, num_epochs])\n    for epoch in range(num_epochs):\n        timer = d2l.Timer()\n        metric = d2l.Accumulator(2)  # 训练损失总和，词元数量\n        for batch in data_iter:\n            optimizer.zero_grad()\n            X, X_valid_len, Y, Y_valid_len = [x.to(device) for x in batch]\n            bos = torch.tensor([tgt_vocab['\u003Cbos>']] * Y.shape[0],\n                          device=device).reshape(-1, 1)\n            dec_input = torch.cat([bos, Y[:, :-1]], 1)  # 强制教学\n            Y_hat, _ = net(X, dec_input, X_valid_len)\n            l = loss(Y_hat, Y, Y_valid_len)\n            l.sum().backward()      # 损失函数的标量进行“反向传播”\n            grad_clipping(net, 1)\n            num_tokens = Y_valid_len.sum()\n            optimizer.step()\n            with torch.no_grad():\n                metric.add(l.sum(), num_tokens)\n        if (epoch + 1) % 10 == 0:\n            animator.add(epoch + 1, (metric[0] / metric[1],))\n    print(f'loss {metric[0] / metric[1]:.3f}, {metric[1] / timer.stop():.1f} '\n        f'tokens/sec on {str(device)}')\n",[2397,11140,11141,11178,11183,11196,11211,11216,11234,11246,11265,11270,11274,11279,11284,11301,11311,11316,11354,11370,11385,11395,11413,11425,11430,11450,11480,11500,11530,11540,11550,11558,11567,11577,11582,11590,11595,11622,11648,11709],{"__ignoreMap":2395},[59,11142,11143,11145,11148,11150,11152,11154,11157,11159,11162,11164,11167,11169,11172,11174,11176],{"class":2401,"line":2402},[59,11144,2406],{"class":2405},[59,11146,11147],{"class":2409}," train_seq2seq",[59,11149,798],{"class":2413},[59,11151,10984],{"class":2416},[59,11153,2420],{"class":2413},[59,11155,11156],{"class":2416},"data_iter",[59,11158,2420],{"class":2413},[59,11160,11161],{"class":2416},"lr",[59,11163,2420],{"class":2413},[59,11165,11166],{"class":2416},"num_epochs",[59,11168,2420],{"class":2413},[59,11170,11171],{"class":2416},"tgt_vocab",[59,11173,2420],{"class":2413},[59,11175,10065],{"class":2416},[59,11177,2426],{"class":2413},[59,11179,11180],{"class":2401,"line":2429},[59,11181,11182],{"class":2432},"    \"\"\"训练序列到序列模型\"\"\"\n",[59,11184,11185,11187,11190,11192,11194],{"class":2401,"line":2436},[59,11186,2657],{"class":2405},[59,11188,11189],{"class":2409}," xavier_init_weights",[59,11191,798],{"class":2413},[59,11193,366],{"class":2416},[59,11195,2426],{"class":2413},[59,11197,11198,11200,11203,11206,11208],{"class":2401,"line":2443},[59,11199,2505],{"class":2405},[59,11201,11202],{"class":2686}," type",[59,11204,11205],{"class":2413},"(m) ",[59,11207,2511],{"class":2405},[59,11209,11210],{"class":2413}," nn.Linear:\n",[59,11212,11213],{"class":2401,"line":2462},[59,11214,11215],{"class":2413},"            nn.init.xavier_uniform_(m.weight)\n",[59,11217,11218,11220,11222,11224,11226,11229,11232],{"class":2401,"line":2483},[59,11219,2505],{"class":2405},[59,11221,11202],{"class":2686},[59,11223,11205],{"class":2413},[59,11225,2511],{"class":2405},[59,11227,11228],{"class":2413}," nn.",[59,11230,11231],{"class":2455},"GRU",[59,11233,2459],{"class":2413},[59,11235,11236,11239,11241,11243],{"class":2401,"line":2491},[59,11237,11238],{"class":2405},"            for",[59,11240,11099],{"class":2413},[59,11242,9458],{"class":2405},[59,11244,11245],{"class":2413}," m._flat_weights_names:\n",[59,11247,11248,11251,11254,11257,11259,11262],{"class":2401,"line":2502},[59,11249,11250],{"class":2405},"                if",[59,11252,11253],{"class":9476}," \"",[59,11255,11256],{"class":9480},"weight",[59,11258,9477],{"class":9476},[59,11260,11261],{"class":2405}," in",[59,11263,11264],{"class":2413}," param:\n",[59,11266,11267],{"class":2401,"line":2519},[59,11268,11269],{"class":2413},"                    nn.init.xavier_uniform_(m._parameters[param])\n",[59,11271,11272],{"class":2401,"line":2535},[59,11273,2725],{"emptyLinePlaceholder":2724},[59,11275,11276],{"class":2401,"line":2543},[59,11277,11278],{"class":2413},"    net.apply(xavier_init_weights)\n",[59,11280,11281],{"class":2401,"line":2560},[59,11282,11283],{"class":2413},"    net.to(device)\n",[59,11285,11286,11289,11291,11294,11296,11298],{"class":2401,"line":2566},[59,11287,11288],{"class":2413},"    optimizer ",[59,11290,815],{"class":2405},[59,11292,11293],{"class":2413}," torch.optim.Adam(net.parameters(), ",[59,11295,11161],{"class":2471},[59,11297,815],{"class":2405},[59,11299,11300],{"class":2413},"lr)\n",[59,11302,11303,11306,11308],{"class":2401,"line":2591},[59,11304,11305],{"class":2413},"    loss ",[59,11307,815],{"class":2405},[59,11309,11310],{"class":2413}," MaskedSoftmaxCELoss()\n",[59,11312,11313],{"class":2401,"line":2604},[59,11314,11315],{"class":2413},"    net.train()\n",[59,11317,11318,11321,11323,11326,11329,11331,11333,11336,11338,11340,11343,11345,11347,11350,11352],{"class":2401,"line":2845},[59,11319,11320],{"class":2413},"    animator ",[59,11322,815],{"class":2405},[59,11324,11325],{"class":2413}," d2l.Animator(",[59,11327,11328],{"class":2471},"xlabel",[59,11330,815],{"class":2405},[59,11332,10892],{"class":9476},[59,11334,11335],{"class":9480},"epoch",[59,11337,10892],{"class":9476},[59,11339,2420],{"class":2413},[59,11341,11342],{"class":2471},"ylabel",[59,11344,815],{"class":2405},[59,11346,10892],{"class":9476},[59,11348,11349],{"class":9480},"loss",[59,11351,10892],{"class":9476},[59,11353,3826],{"class":2413},[59,11355,11356,11359,11361,11364,11367],{"class":2401,"line":4781},[59,11357,11358],{"class":2471},"                     xlim",[59,11360,815],{"class":2405},[59,11362,11363],{"class":2413},"[",[59,11365,11366],{"class":2455},"10",[59,11368,11369],{"class":2413},", num_epochs])\n",[59,11371,11372,11375,11378,11380,11382],{"class":2401,"line":5057},[59,11373,11374],{"class":2405},"    for",[59,11376,11377],{"class":2413}," epoch ",[59,11379,9458],{"class":2405},[59,11381,9462],{"class":9461},[59,11383,11384],{"class":2413},"(num_epochs):\n",[59,11386,11387,11390,11392],{"class":2401,"line":5063},[59,11388,11389],{"class":2413},"        timer ",[59,11391,815],{"class":2405},[59,11393,11394],{"class":2413}," d2l.Timer()\n",[59,11396,11397,11400,11402,11405,11407,11410],{"class":2401,"line":5069},[59,11398,11399],{"class":2413},"        metric ",[59,11401,815],{"class":2405},[59,11403,11404],{"class":2413}," d2l.Accumulator(",[59,11406,2821],{"class":2455},[59,11408,11409],{"class":2413},")  ",[59,11411,11412],{"class":2439},"# 训练损失总和，词元数量\n",[59,11414,11415,11417,11420,11422],{"class":2401,"line":5075},[59,11416,9452],{"class":2405},[59,11418,11419],{"class":2413}," batch ",[59,11421,9458],{"class":2405},[59,11423,11424],{"class":2413}," data_iter:\n",[59,11426,11427],{"class":2401,"line":5096},[59,11428,11429],{"class":2413},"            optimizer.zero_grad()\n",[59,11431,11432,11435,11437,11440,11442,11445,11447],{"class":2401,"line":5115},[59,11433,11434],{"class":2413},"            X, X_valid_len, Y, Y_valid_len ",[59,11436,815],{"class":2405},[59,11438,11439],{"class":2413}," [x.to(device) ",[59,11441,10530],{"class":2405},[59,11443,11444],{"class":2413}," x ",[59,11446,9458],{"class":2405},[59,11448,11449],{"class":2413}," batch]\n",[59,11451,11452,11455,11457,11460,11462,11465,11467,11470,11472,11475,11477],{"class":2401,"line":5134},[59,11453,11454],{"class":2413},"            bos ",[59,11456,815],{"class":2405},[59,11458,11459],{"class":2413}," torch.tensor([tgt_vocab[",[59,11461,10892],{"class":9476},[59,11463,11464],{"class":9480},"\u003Cbos>",[59,11466,10892],{"class":9476},[59,11468,11469],{"class":2413},"]] ",[59,11471,8729],{"class":2405},[59,11473,11474],{"class":2413}," Y.shape[",[59,11476,1754],{"class":2455},[59,11478,11479],{"class":2413},"],\n",[59,11481,11482,11485,11487,11490,11492,11494,11496,11498],{"class":2401,"line":5139},[59,11483,11484],{"class":2471},"                          device",[59,11486,815],{"class":2405},[59,11488,11489],{"class":2413},"device).reshape(",[59,11491,2553],{"class":2405},[59,11493,83],{"class":2455},[59,11495,2420],{"class":2413},[59,11497,83],{"class":2455},[59,11499,2480],{"class":2413},[59,11501,11502,11505,11507,11510,11512,11514,11516,11518,11520,11523,11525,11527],{"class":2401,"line":5155},[59,11503,11504],{"class":2413},"            dec_input ",[59,11506,815],{"class":2405},[59,11508,11509],{"class":2413}," torch.cat([bos, Y[",[59,11511,7184],{"class":7183},[59,11513,2420],{"class":2413},[59,11515,7184],{"class":7183},[59,11517,2553],{"class":2405},[59,11519,83],{"class":2455},[59,11521,11522],{"class":2413},"]], ",[59,11524,83],{"class":2455},[59,11526,11409],{"class":2413},[59,11528,11529],{"class":2439},"# 强制教学\n",[59,11531,11532,11535,11537],{"class":2401,"line":5161},[59,11533,11534],{"class":2413},"            Y_hat, _ ",[59,11536,815],{"class":2405},[59,11538,11539],{"class":2413}," net(X, dec_input, X_valid_len)\n",[59,11541,11542,11545,11547],{"class":2401,"line":5167},[59,11543,11544],{"class":2413},"            l ",[59,11546,815],{"class":2405},[59,11548,11549],{"class":2413}," loss(Y_hat, Y, Y_valid_len)\n",[59,11551,11552,11555],{"class":2401,"line":5177},[59,11553,11554],{"class":2413},"            l.sum().backward()      ",[59,11556,11557],{"class":2439},"# 损失函数的标量进行“反向传播”\n",[59,11559,11560,11563,11565],{"class":2401,"line":5201},[59,11561,11562],{"class":2413},"            grad_clipping(net, ",[59,11564,83],{"class":2455},[59,11566,2480],{"class":2413},[59,11568,11569,11572,11574],{"class":2401,"line":5206},[59,11570,11571],{"class":2413},"            num_tokens ",[59,11573,815],{"class":2405},[59,11575,11576],{"class":2413}," Y_valid_len.sum()\n",[59,11578,11579],{"class":2401,"line":5212},[59,11580,11581],{"class":2413},"            optimizer.step()\n",[59,11583,11584,11587],{"class":2401,"line":5226},[59,11585,11586],{"class":2405},"            with",[59,11588,11589],{"class":2413}," torch.no_grad():\n",[59,11591,11592],{"class":2401,"line":5231},[59,11593,11594],{"class":2413},"                metric.add(l.sum(), num_tokens)\n",[59,11596,11597,11599,11602,11604,11606,11608,11611,11614,11617,11620],{"class":2401,"line":5237},[59,11598,2505],{"class":2405},[59,11600,11601],{"class":2413}," (epoch ",[59,11603,5777],{"class":2405},[59,11605,2514],{"class":2455},[59,11607,7146],{"class":2413},[59,11609,11610],{"class":2405},"%",[59,11612,11613],{"class":2455}," 10",[59,11615,11616],{"class":2405}," ==",[59,11618,11619],{"class":2455}," 0",[59,11621,2459],{"class":2413},[59,11623,11624,11627,11629,11631,11634,11636,11638,11640,11643,11645],{"class":2401,"line":5252},[59,11625,11626],{"class":2413},"            animator.add(epoch ",[59,11628,5777],{"class":2405},[59,11630,2514],{"class":2455},[59,11632,11633],{"class":2413},", (metric[",[59,11635,1754],{"class":2455},[59,11637,7200],{"class":2413},[59,11639,2827],{"class":2405},[59,11641,11642],{"class":2413}," metric[",[59,11644,83],{"class":2455},[59,11646,11647],{"class":2413},"],))\n",[59,11649,11650,11653,11655,11658,11661,11664,11667,11669,11671,11673,11675,11677,11680,11683,11686,11688,11690,11692,11694,11696,11698,11701,11704,11706],{"class":2401,"line":9544},[59,11651,11652],{"class":9461},"    print",[59,11654,798],{"class":2413},[59,11656,11657],{"class":2405},"f",[59,11659,11660],{"class":9480},"'loss ",[59,11662,11663],{"class":2455},"{",[59,11665,11666],{"class":2413},"metric[",[59,11668,1754],{"class":2455},[59,11670,7200],{"class":2413},[59,11672,2827],{"class":2405},[59,11674,11642],{"class":2413},[59,11676,83],{"class":2455},[59,11678,11679],{"class":2413},"]",[59,11681,11682],{"class":2405},":.3f",[59,11684,11685],{"class":2455},"}",[59,11687,2420],{"class":9480},[59,11689,11663],{"class":2455},[59,11691,11666],{"class":2413},[59,11693,83],{"class":2455},[59,11695,7200],{"class":2413},[59,11697,2827],{"class":2405},[59,11699,11700],{"class":2413}," timer.stop()",[59,11702,11703],{"class":2405},":.1f",[59,11705,11685],{"class":2455},[59,11707,11708],{"class":9480}," '\n",[59,11710,11711,11714,11717,11719,11721,11724,11726,11728],{"class":2401,"line":9550},[59,11712,11713],{"class":2405},"        f",[59,11715,11716],{"class":9480},"'tokens/sec on ",[59,11718,11663],{"class":2455},[59,11720,9488],{"class":2686},[59,11722,11723],{"class":2413},"(device)",[59,11725,11685],{"class":2455},[59,11727,10892],{"class":9480},[59,11729,2480],{"class":2413},[15,11731,11732],{},"进行数据集的数据处理",[2390,11734,11736],{"className":2392,"code":11735,"language":2394,"meta":2395,"style":2395},"def read_data_nmt():\n    \"\"\"载入“英语－法语”数据集\n\n    Defined in :numref:`sec_machine_translation`\"\"\"\n    data_dir = d2l.download_extract('fra-eng')\n    with open(os.path.join(data_dir, 'fra.txt'), 'r',\n             encoding='utf-8') as f:\n        return f.read()\n\ndef preprocess_nmt(text):\n    \"\"\"预处理“英语－法语”数据集\n\n    Defined in :numref:`sec_machine_translation`\"\"\"\n    def no_space(char, prev_char):\n        return char in set(',.!?') and prev_char != ' '\n\n    # 使用空格替换不间断空格\n    # 使用小写字母替换大写字母\n    text = text.replace('\\u202f', ' ').replace('\\xa0', ' ').lower()\n    # 在单词和标点符号之间插入空格\n    out = [' ' + char if i > 0 and no_space(char, text[i - 1]) else char\n           for i, char in enumerate(text)]\n    return ''.join(out)\n\ndef tokenize_nmt(text, num_examples=None):\n    \"\"\"词元化“英语－法语”数据数据集\n\n    Defined in :numref:`sec_machine_translation`\"\"\"\n    source, target = [], []\n    for i, line in enumerate(text.split('\\n')):\n        if num_examples and i > num_examples:\n            break\n        parts = line.split('\\t')\n        if len(parts) == 2:\n            source.append(parts[0].split(' '))\n            target.append(parts[1].split(' '))\n    return source, target\n\ndef load_data_nmt(batch_size, num_steps, num_examples=600):\n    \"\"\"返回翻译数据集的迭代器和词表\n\n    Defined in :numref:`subsec_mt_data_loading`\"\"\"\n    text = preprocess_nmt(read_data_nmt())\n    source, target = tokenize_nmt(text, num_examples)\n    src_vocab = d2l.Vocab(source, min_freq=2,\n                          reserved_tokens=['\u003Cpad>', '\u003Cbos>', '\u003Ceos>'])\n    tgt_vocab = d2l.Vocab(target, min_freq=2,\n                          reserved_tokens=['\u003Cpad>', '\u003Cbos>', '\u003Ceos>'])\n    src_array, src_valid_len = build_array_nmt(source, src_vocab, num_steps)\n    tgt_array, tgt_valid_len = build_array_nmt(target, tgt_vocab, num_steps)\n    data_arrays = (src_array, src_valid_len, tgt_array, tgt_valid_len)\n    data_iter = d2l.load_array(data_arrays, batch_size)\n    return data_iter, src_vocab, tgt_vocab\n",[2397,11737,11738,11748,11753,11757,11762,11781,11810,11832,11839,11843,11857,11862,11866,11870,11889,11926,11930,11935,11940,11983,11988,12033,12048,12058,12062,12084,12089,12093,12097,12107,12131,12147,12152,12171,12186,12203,12218,12225,12229,12257,12262,12266,12271,12280,12289,12308,12343,12361,12393,12403,12413,12423,12433],{"__ignoreMap":2395},[59,11739,11740,11742,11745],{"class":2401,"line":2402},[59,11741,2406],{"class":2405},[59,11743,11744],{"class":2409}," read_data_nmt",[59,11746,11747],{"class":2413},"():\n",[59,11749,11750],{"class":2401,"line":2429},[59,11751,11752],{"class":2432},"    \"\"\"载入“英语－法语”数据集\n",[59,11754,11755],{"class":2401,"line":2436},[59,11756,2725],{"emptyLinePlaceholder":2724},[59,11758,11759],{"class":2401,"line":2443},[59,11760,11761],{"class":2432},"    Defined in :numref:`sec_machine_translation`\"\"\"\n",[59,11763,11764,11767,11769,11772,11774,11777,11779],{"class":2401,"line":2462},[59,11765,11766],{"class":2413},"    data_dir ",[59,11768,815],{"class":2405},[59,11770,11771],{"class":2413}," d2l.download_extract(",[59,11773,10892],{"class":9476},[59,11775,11776],{"class":9480},"fra-eng",[59,11778,10892],{"class":9476},[59,11780,2480],{"class":2413},[59,11782,11783,11786,11789,11792,11794,11797,11799,11802,11804,11806,11808],{"class":2401,"line":2483},[59,11784,11785],{"class":2405},"    with",[59,11787,11788],{"class":9461}," open",[59,11790,11791],{"class":2413},"(os.path.join(data_dir, ",[59,11793,10892],{"class":9476},[59,11795,11796],{"class":9480},"fra.txt",[59,11798,10892],{"class":9476},[59,11800,11801],{"class":2413},"), ",[59,11803,10892],{"class":9476},[59,11805,7826],{"class":9480},[59,11807,10892],{"class":9476},[59,11809,3826],{"class":2413},[59,11811,11812,11815,11817,11819,11822,11824,11826,11829],{"class":2401,"line":2491},[59,11813,11814],{"class":2471},"             encoding",[59,11816,815],{"class":2405},[59,11818,10892],{"class":9476},[59,11820,11821],{"class":9480},"utf-8",[59,11823,10892],{"class":9476},[59,11825,7146],{"class":2413},[59,11827,11828],{"class":2405},"as",[59,11830,11831],{"class":2413}," f:\n",[59,11833,11834,11836],{"class":2401,"line":2502},[59,11835,2465],{"class":2405},[59,11837,11838],{"class":2413}," f.read()\n",[59,11840,11841],{"class":2401,"line":2519},[59,11842,2725],{"emptyLinePlaceholder":2724},[59,11844,11845,11847,11850,11852,11855],{"class":2401,"line":2535},[59,11846,2406],{"class":2405},[59,11848,11849],{"class":2409}," preprocess_nmt",[59,11851,798],{"class":2413},[59,11853,11854],{"class":2416},"text",[59,11856,2426],{"class":2413},[59,11858,11859],{"class":2401,"line":2543},[59,11860,11861],{"class":2432},"    \"\"\"预处理“英语－法语”数据集\n",[59,11863,11864],{"class":2401,"line":2560},[59,11865,2725],{"emptyLinePlaceholder":2724},[59,11867,11868],{"class":2401,"line":2566},[59,11869,11761],{"class":2432},[59,11871,11872,11874,11877,11879,11882,11884,11887],{"class":2401,"line":2591},[59,11873,2657],{"class":2405},[59,11875,11876],{"class":2409}," no_space",[59,11878,798],{"class":2413},[59,11880,11881],{"class":2416},"char",[59,11883,2420],{"class":2413},[59,11885,11886],{"class":2416},"prev_char",[59,11888,2426],{"class":2413},[59,11890,11891,11893,11896,11898,11901,11903,11905,11908,11910,11912,11915,11918,11921,11924],{"class":2401,"line":2604},[59,11892,2465],{"class":2405},[59,11894,11895],{"class":2413}," char ",[59,11897,9458],{"class":2405},[59,11899,11900],{"class":2686}," set",[59,11902,798],{"class":2413},[59,11904,10892],{"class":9476},[59,11906,11907],{"class":9480},",.!?",[59,11909,10892],{"class":9476},[59,11911,7146],{"class":2413},[59,11913,11914],{"class":2405},"and",[59,11916,11917],{"class":2413}," prev_char ",[59,11919,11920],{"class":2405},"!=",[59,11922,11923],{"class":9476}," '",[59,11925,11708],{"class":9476},[59,11927,11928],{"class":2401,"line":2845},[59,11929,2725],{"emptyLinePlaceholder":2724},[59,11931,11932],{"class":2401,"line":4781},[59,11933,11934],{"class":2439},"    # 使用空格替换不间断空格\n",[59,11936,11937],{"class":2401,"line":5057},[59,11938,11939],{"class":2439},"    # 使用小写字母替换大写字母\n",[59,11941,11942,11945,11947,11950,11952,11956,11958,11960,11962,11964,11967,11969,11972,11974,11976,11978,11980],{"class":2401,"line":5063},[59,11943,11944],{"class":2413},"    text ",[59,11946,815],{"class":2405},[59,11948,11949],{"class":2413}," text.replace(",[59,11951,10892],{"class":9476},[59,11953,11955],{"class":11954},"sRfyP","\\u202f",[59,11957,10892],{"class":9476},[59,11959,2420],{"class":2413},[59,11961,10892],{"class":9476},[59,11963,11923],{"class":9476},[59,11965,11966],{"class":2413},").replace(",[59,11968,10892],{"class":9476},[59,11970,11971],{"class":11954},"\\xa0",[59,11973,10892],{"class":9476},[59,11975,2420],{"class":2413},[59,11977,10892],{"class":9476},[59,11979,11923],{"class":9476},[59,11981,11982],{"class":2413},").lower()\n",[59,11984,11985],{"class":2401,"line":5069},[59,11986,11987],{"class":2439},"    # 在单词和标点符号之间插入空格\n",[59,11989,11990,11993,11995,11997,11999,12001,12004,12006,12008,12010,12012,12014,12017,12020,12022,12024,12027,12030],{"class":2401,"line":5075},[59,11991,11992],{"class":2413},"    out ",[59,11994,815],{"class":2405},[59,11996,9592],{"class":2413},[59,11998,10892],{"class":9476},[59,12000,11923],{"class":9476},[59,12002,12003],{"class":2405}," +",[59,12005,11895],{"class":2413},[59,12007,11029],{"class":2405},[59,12009,9455],{"class":2413},[59,12011,11089],{"class":2405},[59,12013,11619],{"class":2455},[59,12015,12016],{"class":2405}," and",[59,12018,12019],{"class":2413}," no_space(char, text[i ",[59,12021,2553],{"class":2405},[59,12023,2514],{"class":2455},[59,12025,12026],{"class":2413},"]) ",[59,12028,12029],{"class":2405},"else",[59,12031,12032],{"class":2413}," char\n",[59,12034,12035,12038,12041,12043,12045],{"class":2401,"line":5096},[59,12036,12037],{"class":2405},"           for",[59,12039,12040],{"class":2413}," i, char ",[59,12042,9458],{"class":2405},[59,12044,9621],{"class":9461},[59,12046,12047],{"class":2413},"(text)]\n",[59,12049,12050,12052,12055],{"class":2401,"line":5115},[59,12051,4687],{"class":2405},[59,12053,12054],{"class":9476}," ''",[59,12056,12057],{"class":2413},".join(out)\n",[59,12059,12060],{"class":2401,"line":5134},[59,12061,2725],{"emptyLinePlaceholder":2724},[59,12063,12064,12066,12069,12071,12073,12075,12078,12080,12082],{"class":2401,"line":5139},[59,12065,2406],{"class":2405},[59,12067,12068],{"class":2409}," tokenize_nmt",[59,12070,798],{"class":2413},[59,12072,11854],{"class":2416},[59,12074,2420],{"class":2413},[59,12076,12077],{"class":2416},"num_examples",[59,12079,815],{"class":2405},[59,12081,2780],{"class":2455},[59,12083,2426],{"class":2413},[59,12085,12086],{"class":2401,"line":5155},[59,12087,12088],{"class":2432},"    \"\"\"词元化“英语－法语”数据数据集\n",[59,12090,12091],{"class":2401,"line":5161},[59,12092,2725],{"emptyLinePlaceholder":2724},[59,12094,12095],{"class":2401,"line":5167},[59,12096,11761],{"class":2432},[59,12098,12099,12102,12104],{"class":2401,"line":5177},[59,12100,12101],{"class":2413},"    source, target ",[59,12103,815],{"class":2405},[59,12105,12106],{"class":2413}," [], []\n",[59,12108,12109,12111,12114,12116,12118,12121,12123,12126,12128],{"class":2401,"line":5201},[59,12110,11374],{"class":2405},[59,12112,12113],{"class":2413}," i, line ",[59,12115,9458],{"class":2405},[59,12117,9621],{"class":9461},[59,12119,12120],{"class":2413},"(text.split(",[59,12122,10892],{"class":9476},[59,12124,12125],{"class":11954},"\\n",[59,12127,10892],{"class":9476},[59,12129,12130],{"class":2413},")):\n",[59,12132,12133,12135,12138,12140,12142,12144],{"class":2401,"line":5206},[59,12134,2505],{"class":2405},[59,12136,12137],{"class":2413}," num_examples ",[59,12139,11914],{"class":2405},[59,12141,9455],{"class":2413},[59,12143,11089],{"class":2405},[59,12145,12146],{"class":2413}," num_examples:\n",[59,12148,12149],{"class":2401,"line":5212},[59,12150,12151],{"class":2405},"            break\n",[59,12153,12154,12157,12159,12162,12164,12167,12169],{"class":2401,"line":5226},[59,12155,12156],{"class":2413},"        parts ",[59,12158,815],{"class":2405},[59,12160,12161],{"class":2413}," line.split(",[59,12163,10892],{"class":9476},[59,12165,12166],{"class":11954},"\\t",[59,12168,10892],{"class":9476},[59,12170,2480],{"class":2413},[59,12172,12173,12175,12177,12180,12182,12184],{"class":2401,"line":5231},[59,12174,2505],{"class":2405},[59,12176,9601],{"class":9461},[59,12178,12179],{"class":2413},"(parts) ",[59,12181,2511],{"class":2405},[59,12183,11068],{"class":2455},[59,12185,2459],{"class":2413},[59,12187,12188,12191,12193,12196,12198,12200],{"class":2401,"line":5237},[59,12189,12190],{"class":2413},"            source.append(parts[",[59,12192,1754],{"class":2455},[59,12194,12195],{"class":2413},"].split(",[59,12197,10892],{"class":9476},[59,12199,11923],{"class":9476},[59,12201,12202],{"class":2413},"))\n",[59,12204,12205,12208,12210,12212,12214,12216],{"class":2401,"line":5252},[59,12206,12207],{"class":2413},"            target.append(parts[",[59,12209,83],{"class":2455},[59,12211,12195],{"class":2413},[59,12213,10892],{"class":9476},[59,12215,11923],{"class":9476},[59,12217,12202],{"class":2413},[59,12219,12220,12222],{"class":2401,"line":9544},[59,12221,4687],{"class":2405},[59,12223,12224],{"class":2413}," source, target\n",[59,12226,12227],{"class":2401,"line":9550},[59,12228,2725],{"emptyLinePlaceholder":2724},[59,12230,12231,12233,12236,12238,12241,12243,12246,12248,12250,12252,12255],{"class":2401,"line":9556},[59,12232,2406],{"class":2405},[59,12234,12235],{"class":2409}," load_data_nmt",[59,12237,798],{"class":2413},[59,12239,12240],{"class":2416},"batch_size",[59,12242,2420],{"class":2413},[59,12244,12245],{"class":2416},"num_steps",[59,12247,2420],{"class":2413},[59,12249,12077],{"class":2416},[59,12251,815],{"class":2405},[59,12253,12254],{"class":2455},"600",[59,12256,2426],{"class":2413},[59,12258,12259],{"class":2401,"line":9583},[59,12260,12261],{"class":2432},"    \"\"\"返回翻译数据集的迭代器和词表\n",[59,12263,12264],{"class":2401,"line":9611},[59,12265,2725],{"emptyLinePlaceholder":2724},[59,12267,12268],{"class":2401,"line":9631},[59,12269,12270],{"class":2432},"    Defined in :numref:`subsec_mt_data_loading`\"\"\"\n",[59,12272,12273,12275,12277],{"class":2401,"line":9642},[59,12274,11944],{"class":2413},[59,12276,815],{"class":2405},[59,12278,12279],{"class":2413}," preprocess_nmt(read_data_nmt())\n",[59,12281,12282,12284,12286],{"class":2401,"line":9655},[59,12283,12101],{"class":2413},[59,12285,815],{"class":2405},[59,12287,12288],{"class":2413}," tokenize_nmt(text, num_examples)\n",[59,12290,12291,12294,12296,12299,12302,12304,12306],{"class":2401,"line":10158},[59,12292,12293],{"class":2413},"    src_vocab ",[59,12295,815],{"class":2405},[59,12297,12298],{"class":2413}," d2l.Vocab(source, ",[59,12300,12301],{"class":2471},"min_freq",[59,12303,815],{"class":2405},[59,12305,2821],{"class":2455},[59,12307,3826],{"class":2413},[59,12309,12310,12313,12315,12317,12319,12322,12324,12326,12328,12330,12332,12334,12336,12339,12341],{"class":2401,"line":10173},[59,12311,12312],{"class":2471},"                          reserved_tokens",[59,12314,815],{"class":2405},[59,12316,11363],{"class":2413},[59,12318,10892],{"class":9476},[59,12320,12321],{"class":9480},"\u003Cpad>",[59,12323,10892],{"class":9476},[59,12325,2420],{"class":2413},[59,12327,10892],{"class":9476},[59,12329,11464],{"class":9480},[59,12331,10892],{"class":9476},[59,12333,2420],{"class":2413},[59,12335,10892],{"class":9476},[59,12337,12338],{"class":9480},"\u003Ceos>",[59,12340,10892],{"class":9476},[59,12342,2532],{"class":2413},[59,12344,12345,12348,12350,12353,12355,12357,12359],{"class":2401,"line":10178},[59,12346,12347],{"class":2413},"    tgt_vocab ",[59,12349,815],{"class":2405},[59,12351,12352],{"class":2413}," d2l.Vocab(target, ",[59,12354,12301],{"class":2471},[59,12356,815],{"class":2405},[59,12358,2821],{"class":2455},[59,12360,3826],{"class":2413},[59,12362,12363,12365,12367,12369,12371,12373,12375,12377,12379,12381,12383,12385,12387,12389,12391],{"class":2401,"line":10193},[59,12364,12312],{"class":2471},[59,12366,815],{"class":2405},[59,12368,11363],{"class":2413},[59,12370,10892],{"class":9476},[59,12372,12321],{"class":9480},[59,12374,10892],{"class":9476},[59,12376,2420],{"class":2413},[59,12378,10892],{"class":9476},[59,12380,11464],{"class":9480},[59,12382,10892],{"class":9476},[59,12384,2420],{"class":2413},[59,12386,10892],{"class":9476},[59,12388,12338],{"class":9480},[59,12390,10892],{"class":9476},[59,12392,2532],{"class":2413},[59,12394,12395,12398,12400],{"class":2401,"line":10222},[59,12396,12397],{"class":2413},"    src_array, src_valid_len ",[59,12399,815],{"class":2405},[59,12401,12402],{"class":2413}," build_array_nmt(source, src_vocab, num_steps)\n",[59,12404,12405,12408,12410],{"class":2401,"line":10241},[59,12406,12407],{"class":2413},"    tgt_array, tgt_valid_len ",[59,12409,815],{"class":2405},[59,12411,12412],{"class":2413}," build_array_nmt(target, tgt_vocab, num_steps)\n",[59,12414,12415,12418,12420],{"class":2401,"line":10262},[59,12416,12417],{"class":2413},"    data_arrays ",[59,12419,815],{"class":2405},[59,12421,12422],{"class":2413}," (src_array, src_valid_len, tgt_array, tgt_valid_len)\n",[59,12424,12425,12428,12430],{"class":2401,"line":10282},[59,12426,12427],{"class":2413},"    data_iter ",[59,12429,815],{"class":2405},[59,12431,12432],{"class":2413}," d2l.load_array(data_arrays, batch_size)\n",[59,12434,12435,12437],{"class":2401,"line":10293},[59,12436,4687],{"class":2405},[59,12438,12439],{"class":2413}," data_iter, src_vocab, tgt_vocab\n",[15,12441,12442],{},"之后按照 transformer 架构创建编码器-解码器架构，",[2390,12444,12446],{"className":2392,"code":12445,"language":2394,"meta":2395,"style":2395},"num_hiddens, num_layers, dropout, batch_size, num_steps = 32, 2, 0.1, 64, 10\nlr, num_epochs, device = 0.005, 200, torch.device('cuda:0') if torch.cuda.device_count() >= 1 else torch.device('cpu')\nffn_num_hiddens, num_heads = 64, 4\n\ntrain_iter, src_vocab, tgt_vocab = d2l.load_data_nmt(batch_size, num_steps)\n\nencoder = TransformerEncoder(\n    len(src_vocab), num_hiddens, ffn_num_hiddens, num_heads, num_layers,\n    dropout)\ndecoder = TransformerDecoder(\n    len(tgt_vocab), num_hiddens, ffn_num_hiddens, num_heads, num_layers,\n    dropout)\nnet = EncoderDecoder(encoder, decoder)\ntrain_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)\n",[2397,12447,12448,12477,12529,12544,12548,12558,12562,12572,12580,12585,12595,12602,12606,12616],{"__ignoreMap":2395},[59,12449,12450,12453,12455,12458,12460,12462,12464,12467,12469,12472,12474],{"class":2401,"line":2402},[59,12451,12452],{"class":2413},"num_hiddens, num_layers, dropout, batch_size, num_steps ",[59,12454,815],{"class":2405},[59,12456,12457],{"class":2455}," 32",[59,12459,2420],{"class":2413},[59,12461,2821],{"class":2455},[59,12463,2420],{"class":2413},[59,12465,12466],{"class":2455},"0.1",[59,12468,2420],{"class":2413},[59,12470,12471],{"class":2455},"64",[59,12473,2420],{"class":2413},[59,12475,12476],{"class":2455},"10\n",[59,12478,12479,12482,12484,12487,12489,12492,12495,12497,12500,12502,12504,12506,12509,12512,12514,12517,12520,12522,12525,12527],{"class":2401,"line":2429},[59,12480,12481],{"class":2413},"lr, num_epochs, device ",[59,12483,815],{"class":2405},[59,12485,12486],{"class":2455}," 0.005",[59,12488,2420],{"class":2413},[59,12490,12491],{"class":2455},"200",[59,12493,12494],{"class":2413},", torch.device(",[59,12496,10892],{"class":9476},[59,12498,12499],{"class":9480},"cuda:0",[59,12501,10892],{"class":9476},[59,12503,7146],{"class":2413},[59,12505,11029],{"class":2405},[59,12507,12508],{"class":2413}," torch.cuda.device_count() ",[59,12510,12511],{"class":2405},">=",[59,12513,2514],{"class":2455},[59,12515,12516],{"class":2405}," else",[59,12518,12519],{"class":2413}," torch.device(",[59,12521,10892],{"class":9476},[59,12523,12524],{"class":9480},"cpu",[59,12526,10892],{"class":9476},[59,12528,2480],{"class":2413},[59,12530,12531,12534,12536,12539,12541],{"class":2401,"line":2436},[59,12532,12533],{"class":2413},"ffn_num_hiddens, num_heads ",[59,12535,815],{"class":2405},[59,12537,12538],{"class":2455}," 64",[59,12540,2420],{"class":2413},[59,12542,12543],{"class":2455},"4\n",[59,12545,12546],{"class":2401,"line":2443},[59,12547,2725],{"emptyLinePlaceholder":2724},[59,12549,12550,12553,12555],{"class":2401,"line":2462},[59,12551,12552],{"class":2413},"train_iter, src_vocab, tgt_vocab ",[59,12554,815],{"class":2405},[59,12556,12557],{"class":2413}," d2l.load_data_nmt(batch_size, num_steps)\n",[59,12559,12560],{"class":2401,"line":2483},[59,12561,2725],{"emptyLinePlaceholder":2724},[59,12563,12564,12567,12569],{"class":2401,"line":2491},[59,12565,12566],{"class":2413},"encoder ",[59,12568,815],{"class":2405},[59,12570,12571],{"class":2413}," TransformerEncoder(\n",[59,12573,12574,12577],{"class":2401,"line":2502},[59,12575,12576],{"class":9461},"    len",[59,12578,12579],{"class":2413},"(src_vocab), num_hiddens, ffn_num_hiddens, num_heads, num_layers,\n",[59,12581,12582],{"class":2401,"line":2519},[59,12583,12584],{"class":2413},"    dropout)\n",[59,12586,12587,12590,12592],{"class":2401,"line":2535},[59,12588,12589],{"class":2413},"decoder ",[59,12591,815],{"class":2405},[59,12593,12594],{"class":2413}," TransformerDecoder(\n",[59,12596,12597,12599],{"class":2401,"line":2543},[59,12598,12576],{"class":9461},[59,12600,12601],{"class":2413},"(tgt_vocab), num_hiddens, ffn_num_hiddens, num_heads, num_layers,\n",[59,12603,12604],{"class":2401,"line":2560},[59,12605,12584],{"class":2413},[59,12607,12608,12611,12613],{"class":2401,"line":2566},[59,12609,12610],{"class":2413},"net ",[59,12612,815],{"class":2405},[59,12614,12615],{"class":2413}," EncoderDecoder(encoder, decoder)\n",[59,12617,12618],{"class":2401,"line":2591},[59,12619,12620],{"class":2413},"train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)\n",[11,12622,12623],{"id":12623},"参考文献",[15,12625,12626],{},[22,12627,33],{"href":31,"rel":12628},[26],[15,12630,12631],{},[22,12632,27],{"href":24,"rel":12633},[26],[12635,12636,12637],"style",{},"html pre.shiki code .szJfE, html code.shiki .szJfE{--shiki-default:#D73A49;--shiki-dark:#FF79C6}html pre.shiki code .sCLZk, html code.shiki .sCLZk{--shiki-default:#6F42C1;--shiki-dark:#50FA7B}html pre.shiki code .scbbO, html code.shiki .scbbO{--shiki-default:#24292E;--shiki-dark:#F8F8F2}html pre.shiki code .syNf4, html code.shiki .syNf4{--shiki-default:#24292E;--shiki-default-font-style:inherit;--shiki-dark:#FFB86C;--shiki-dark-font-style:italic}html pre.shiki code .seLWX, html code.shiki .seLWX{--shiki-default:#032F62;--shiki-dark:#6272A4}html pre.shiki code .sfgPZ, html code.shiki .sfgPZ{--shiki-default:#6A737D;--shiki-dark:#6272A4}html pre.shiki code .soDru, html code.shiki .soDru{--shiki-default:#005CC5;--shiki-dark:#BD93F9}html pre.shiki code .sQkXh, html code.shiki .sQkXh{--shiki-default:#E36209;--shiki-default-font-style:inherit;--shiki-dark:#FFB86C;--shiki-dark-font-style:italic}html .default .shiki span {color: var(--shiki-default);background: var(--shiki-default-bg);font-style: var(--shiki-default-font-style);font-weight: var(--shiki-default-font-weight);text-decoration: var(--shiki-default-text-decoration);}html .shiki span {color: var(--shiki-default);background: var(--shiki-default-bg);font-style: var(--shiki-default-font-style);font-weight: var(--shiki-default-font-weight);text-decoration: var(--shiki-default-text-decoration);}html .dark .shiki span {color: var(--shiki-dark);background: var(--shiki-dark-bg);font-style: var(--shiki-dark-font-style);font-weight: var(--shiki-dark-font-weight);text-decoration: var(--shiki-dark-text-decoration);}html.dark .shiki span {color: var(--shiki-dark);background: var(--shiki-dark-bg);font-style: var(--shiki-dark-font-style);font-weight: var(--shiki-dark-font-weight);text-decoration: var(--shiki-dark-text-decoration);}html pre.shiki code .skCyd, html code.shiki .skCyd{--shiki-default:#6F42C1;--shiki-dark:#8BE9FD}html pre.shiki code .sDP9b, html code.shiki .sDP9b{--shiki-default:#6F42C1;--shiki-default-font-style:inherit;--shiki-dark:#8BE9FD;--shiki-dark-font-style:italic}html pre.shiki code .sD3jg, html code.shiki .sD3jg{--shiki-default:#24292E;--shiki-default-font-style:inherit;--shiki-dark:#BD93F9;--shiki-dark-font-style:italic}html pre.shiki code .sPGBF, html code.shiki .sPGBF{--shiki-default:#005CC5;--shiki-default-font-style:inherit;--shiki-dark:#8BE9FD;--shiki-dark-font-style:italic}html pre.shiki code .sJti5, html code.shiki .sJti5{--shiki-default:#005CC5;--shiki-default-font-style:inherit;--shiki-dark:#BD93F9;--shiki-dark-font-style:italic}html pre.shiki code .sDoOe, html code.shiki .sDoOe{--shiki-default:#24292E;--shiki-dark:#FF79C6}html pre.shiki code .sDgm9, html code.shiki .sDgm9{--shiki-default:#005CC5;--shiki-dark:#8BE9FD}html pre.shiki code .sMWOi, html code.shiki .sMWOi{--shiki-default:#032F62;--shiki-dark:#E9F284}html pre.shiki code .sEzAm, html code.shiki .sEzAm{--shiki-default:#032F62;--shiki-dark:#F1FA8C}html pre.shiki code .sRfyP, html code.shiki .sRfyP{--shiki-default:#005CC5;--shiki-dark:#FF79C6}",{"title":2395,"searchDepth":2429,"depth":2429,"links":12639},[12640,12641,12642,12646,12647,12648,12649,12650,12651,12652,12653],{"id":13,"depth":2429,"text":13},{"id":37,"depth":2429,"text":37},{"id":47,"depth":2429,"text":47,"children":12643},[12644,12645],{"id":51,"depth":2436,"text":51},{"id":2870,"depth":2436,"text":2870},{"id":5262,"depth":2429,"text":5262},{"id":7297,"depth":2429,"text":7297},{"id":7793,"depth":2429,"text":7793},{"id":8130,"depth":2429,"text":8130},{"id":9044,"depth":2429,"text":9045},{"id":9663,"depth":2429,"text":9664},{"id":10670,"depth":2429,"text":10670},{"id":12623,"depth":2429,"text":12623},"距离首次学习 transformer 已经过去一年，内容忘的差不多了，决定复习一下。","md",{"date":12657,"image":12658,"alt":6,"tags":12659,"published":2724},"30st Jan 2026","/blogs-img/blog1.jpg",[12660,12661],"deep-learning","学习笔记","/blogs/transformer",{"title":6,"description":12654},"blogs/1. transformer复习笔记","59tIbWfqdpJGJRh5PHVoEHC170F6vnech8kVXdG8MGw",1774096958854]