@@ -1644,6 +1644,207 @@ struct llama_sampler * llama_sampler_init_logit_bias(
     };
 }
 
+// infill
+
+// #define GGML_DEBUG_SAMPLER_INFILL
+
+struct llama_sampler_infill {
+    const struct llama_vocab * vocab;
+};
+
+static const char * llama_sampler_infill_name(const struct llama_sampler * /*smpl*/) {
+    return "infill";
+}
+
+static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+    auto * ctx = (llama_sampler_infill *) smpl->ctx;
+
+    llama_sampler_softmax_impl(cur_p);
+
+#if defined(GGML_DEBUG_SAMPLER_INFILL)
+#define LOG_DBG_CUR LLAMA_LOG_DEBUG
+#else
+#define LOG_DBG_CUR(...)
+#endif
+
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
+    }
+
+    float p_txt_sum = 0.0f;
+    float p_eog_sum = 0.0f;
+
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        if (llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id)) {
+            p_eog_sum += cur_p->data[i].p;
+        } else {
+            p_txt_sum += cur_p->data[i].p;
+        }
+    }
+
+    const float rat = p_eog_sum == 0.0 ? INFINITY : p_txt_sum / p_eog_sum; GGML_UNUSED(rat);
+
+    LOG_DBG_CUR("%s: p_txt_sum = %.2f, p_eog_sum = %.2f, rat = %.2f, n = %zu\n", __func__, p_txt_sum, p_eog_sum, rat, cur_p->size);
+
+    if (3*p_eog_sum*cur_p->size > p_txt_sum) {
+        LOG_DBG_CUR("%s: the ratio p_txt/p_eog = %.2f is too low -> sampling EOG\n", __func__, p_txt_sum/p_eog_sum);
+
+        // keep just the EOG tokens
+        const auto size_org = cur_p->size;
+
+        cur_p->size = 0;
+
+        float p_sum = 0.0f;
+
+        for (size_t i = 0; i < size_org; ++i) {
+            if (llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id)) {
+                p_sum += cur_p->data[i].p;
+
+                cur_p->data[cur_p->size++] = cur_p->data[i];
+            }
+        }
+
+        // normalize probs
+        for (size_t i = 0; i < cur_p->size; ++i) {
+            cur_p->data[i].p /= p_sum;
+        }
+
+        return;
+    }
+
+    size_t n_combined = 0; GGML_UNUSED(n_combined);
+
+    // combine tokens with common prefix
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        for (size_t j = 0; j < cur_p->size; ++j) {
+            if (cur_p->data[i].logit == -INFINITY) {
+                break;
+            }
+
+            if (i == j || cur_p->data[j].logit == -INFINITY) {
+                continue;
+            }
+
+            if (llama_token_is_prefix_impl(*ctx->vocab, cur_p->data[i].id, cur_p->data[j].id)) {
+                if (cur_p->data[i].p > cur_p->data[j].p) {
+                    cur_p->data[i].p += cur_p->data[j].p;
+                    cur_p->data[j].logit = -INFINITY;
+                    cur_p->data[j].p     = 0.0f;
+                } else {
+                    cur_p->data[j].p += cur_p->data[i].p;
+                    cur_p->data[i].logit = -INFINITY;
+                    cur_p->data[i].p     = 0.0f;
+                }
+
+                n_combined++;
+            }
+        }
+    }
+
+    size_t n_non_eog = 0;
+
+    size_t size_org = cur_p->size;
+
+    float p_sum = 0.0f;
+    float thold = 0.2f;
+
+    cur_p->size = 0;
+
+    LOG_DBG_CUR("%s: n_combined = %zu, applying thold = %.3f\n", __func__, n_combined, thold);
+
+    for (size_t i = 0; i < size_org; ++i) {
+        const bool is_eog = llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id);
+
+        if (cur_p->data[i].p < thold && !is_eog) {
+            continue;
+        }
+
+        if (!is_eog) {
+            ++n_non_eog;
+        }
+
+        p_sum += cur_p->data[i].p;
+
+        // keep this token
+        cur_p->data[cur_p->size++] = cur_p->data[i];
+    }
+
+    LOG_DBG_CUR("%s: n_non_eog = %zu\n", __func__, n_non_eog);
+
+    // if no non-EOG tokens are left -> reduce cur_p to single EOT token
+    if (n_non_eog == 0) {
+        cur_p->size = 1;
+        cur_p->data[0].id = llama_token_eot_impl(*ctx->vocab);
+        cur_p->data[0].logit = 1.0f;
+
+        return;
+    }
+
+    // normalize probs
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        cur_p->data[i].p /= p_sum;
+
+        LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
+    }
+
+    size_org = cur_p->size;
+    p_sum = 0.0f;
+    thold = 1.0/(n_non_eog + 1);
+
+    cur_p->size = 0;
+
+    LOG_DBG_CUR("%s: applying thold = %.3f\n", __func__, thold);
+
+    for (size_t i = 0; i < size_org; ++i) {
+        const bool is_eog = llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id);
+
+        if (cur_p->data[i].p < thold && !is_eog) {
+            continue;
+        }
+
+        p_sum += cur_p->data[i].p;
+
+        cur_p->data[cur_p->size++] = cur_p->data[i];
+    }
+
+    // normalize probs
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        cur_p->data[i].p /= p_sum;
+
+        LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
+    }
+
+#undef LOG_DBG_CUR
+}
+
+static struct llama_sampler * llama_sampler_infill_clone(const struct llama_sampler * smpl) {
+    const auto * ctx = (const llama_sampler_infill *) smpl->ctx;
+    return llama_sampler_init_infill_impl(*ctx->vocab);
+}
+
+static void llama_sampler_infill_free(struct llama_sampler * smpl) {
+    delete (llama_sampler_infill *) smpl->ctx;
+}
+
+static struct llama_sampler_i llama_sampler_infill_i = {
+    /* .name   = */ llama_sampler_infill_name,
+    /* .accept = */ nullptr,
+    /* .apply  = */ llama_sampler_infill_apply,
+    /* .reset  = */ nullptr,
+    /* .clone  = */ llama_sampler_infill_clone,
+    /* .free   = */ llama_sampler_infill_free,
+};
+
+struct llama_sampler * llama_sampler_init_infill_impl(
+        const struct llama_vocab & vocab) {
+    return new llama_sampler {
+        /* .iface = */ &llama_sampler_infill_i,
+        /* .ctx   = */ new llama_sampler_infill {
+            /* .vocab = */ &vocab,
+        },
+    };
+}
+
 
 // utils
 
 uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl) {
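
Usage sketch (not part of the diff above): one way the new sampler could be wired up through the public sampler-chain API. It assumes llama.h in this revision exposes llama_sampler_init_infill(model) as the wrapper around llama_sampler_init_infill_impl, alongside the existing llama_sampler_chain_* and llama_sampler_init_dist helpers; verify the exact signatures in llama.h.

    // model and ctx are assumed to be valid llama_model / llama_context handles
    struct llama_sampler_chain_params sparams = llama_sampler_chain_default_params();
    struct llama_sampler * chain = llama_sampler_chain_init(sparams);

    // prune/merge infill candidates first, then draw the final token
    llama_sampler_chain_add(chain, llama_sampler_init_infill(model));
    llama_sampler_chain_add(chain, llama_sampler_init_dist(LLAMA_DEFAULT_SEED));

    const llama_token id = llama_sampler_sample(chain, ctx, -1); // sample from the last decoded logits

    llama_sampler_free(chain);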