@@ -20,7 +20,17 @@ extern "C" __declspec(dllimport) int __stdcall WideCharToMultiByte(unsigned int
2020 const wchar_t * lpWideCharStr, int cchWideChar,
2121 char * lpMultiByteStr, int cbMultiByte,
2222 const char * lpDefaultChar, bool * lpUsedDefaultChar);
23+ #define ENABLE_LINE_INPUT 0x0002
24+ #define ENABLE_ECHO_INPUT 0x0004
2325#define CP_UTF8 65001
26+ #define CONSOLE_CHAR_TYPE wchar_t
27+ #define CONSOLE_GET_CHAR () getwchar()
28+ #define CONSOLE_EOF WEOF
29+ #else
30+ #include < unistd.h>
31+ #define CONSOLE_CHAR_TYPE char
32+ #define CONSOLE_GET_CHAR () getchar()
33+ #define CONSOLE_EOF EOF
2434#endif
2535
2636bool gpt_params_parse (int argc, char ** argv, gpt_params & params) {
@@ -160,6 +170,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
160170 params.interactive = true ;
161171 } else if (arg == " --interactive-first" ) {
162172 params.interactive_start = true ;
173+ } else if (arg == " --author-mode" ) {
174+ params.author_mode = true ;
163175 } else if (arg == " -ins" || arg == " --instruct" ) {
164176 params.instruct = true ;
165177 } else if (arg == " --color" ) {
@@ -222,6 +234,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
222234 fprintf (stderr, " -i, --interactive run in interactive mode\n " );
223235 fprintf (stderr, " --interactive-first run in interactive mode and wait for input right away\n " );
224236 fprintf (stderr, " -ins, --instruct run in instruction mode (use with Alpaca models)\n " );
237+ fprintf (stderr, " --author-mode allows you to write or paste multiple lines without ending each in '\\ '\n " );
225238 fprintf (stderr, " -r PROMPT, --reverse-prompt PROMPT\n " );
226239 fprintf (stderr, " run in interactive mode and poll user input upon seeing PROMPT (can be\n " );
227240 fprintf (stderr, " specified more than once for multiple prompts).\n " );
@@ -293,7 +306,7 @@ std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::s
293306}
294307
295308/* Keep track of current color of output, and emit ANSI code if it changes. */
296- void set_console_color (console_state & con_st, console_color_t color) {
309+ void console_set_color (console_state & con_st, console_color_t color) {
297310 if (con_st.use_color && con_st.color != color) {
298311 switch (color) {
299312 case CONSOLE_COLOR_DEFAULT:
@@ -310,8 +323,9 @@ void set_console_color(console_state & con_st, console_color_t color) {
310323 }
311324}
312325
326+ void console_init (console_state & con_st) {
313327#if defined (_WIN32)
314- void win32_console_init ( bool enable_color) {
328+ // Windows-specific console initialization
315329 unsigned long dwMode = 0 ;
316330 void * hConOut = GetStdHandle ((unsigned long )-11 ); // STD_OUTPUT_HANDLE (-11)
317331 if (!hConOut || hConOut == (void *)-1 || !GetConsoleMode (hConOut, &dwMode)) {
@@ -322,7 +336,7 @@ void win32_console_init(bool enable_color) {
322336 }
323337 if (hConOut) {
324338 // Enable ANSI colors on Windows 10+
325- if (enable_color && !(dwMode & 0x4 )) {
339+ if (con_st. use_color && !(dwMode & 0x4 )) {
326340 SetConsoleMode (hConOut, dwMode | 0x4 ); // ENABLE_VIRTUAL_TERMINAL_PROCESSING (0x4)
327341 }
328342 // Set console output codepage to UTF8
@@ -332,9 +346,46 @@ void win32_console_init(bool enable_color) {
332346 if (hConIn && hConIn != (void *)-1 && GetConsoleMode (hConIn, &dwMode)) {
333347 // Set console input codepage to UTF16
334348 _setmode (_fileno (stdin), _O_WTEXT);
349+
350+ // Turn off ICANON (ENABLE_LINE_INPUT) and ECHO (ENABLE_ECHO_INPUT)
351+ dwMode &= ~(ENABLE_LINE_INPUT | ENABLE_ECHO_INPUT);
352+ SetConsoleMode (hConIn, dwMode);
353+ }
354+ #else
355+ // POSIX-specific console initialization
356+ struct termios new_termios;
357+ tcgetattr (STDIN_FILENO, &con_st.prev_state );
358+ new_termios = con_st.prev_state ;
359+ new_termios.c_lflag &= ~(ICANON | ECHO);
360+ new_termios.c_cc [VMIN] = 1 ;
361+ new_termios.c_cc [VTIME] = 0 ;
362+ tcsetattr (STDIN_FILENO, TCSANOW, &new_termios);
363+ #endif
364+ }
365+
366+ void console_cleanup (console_state & con_st) {
367+ #if !defined(_WIN32)
368+ // Restore the terminal settings on POSIX systems
369+ tcsetattr (STDIN_FILENO, TCSANOW, &con_st.prev_state );
370+ #endif
371+
372+ // Reset console color
373+ console_set_color (con_st, CONSOLE_COLOR_DEFAULT);
374+ }
375+
376+ // Helper function to remove the last UTF-8 character from a string
377+ void remove_last_utf8_char (std::string & line) {
378+ if (line.empty ()) return ;
379+ size_t pos = line.length () - 1 ;
380+
381+ // Find the start of the last UTF-8 character (checking up to 4 bytes back)
382+ for (size_t i = 0 ; i < 3 && pos > 0 ; ++i, --pos) {
383+ if ((line[pos] & 0xC0 ) != 0x80 ) break ; // Found the start of the character
335384 }
385+ line.erase (pos);
336386}
337387
388+ #if defined (_WIN32)
338389// Convert a wide Unicode string to an UTF8 string
339390void win32_utf8_encode (const std::wstring & wstr, std::string & str) {
340391 int size_needed = WideCharToMultiByte (CP_UTF8, 0 , &wstr[0 ], (int )wstr.size (), NULL , 0 , NULL , NULL );
@@ -343,3 +394,99 @@ void win32_utf8_encode(const std::wstring & wstr, std::string & str) {
343394 str = strTo;
344395}
345396#endif
397+
398+ bool console_readline (console_state & con_st, std::string & line) {
399+ line.clear ();
400+ bool is_special_char = false ;
401+ bool end_of_stream = false ;
402+
403+ console_set_color (con_st, CONSOLE_COLOR_USER_INPUT);
404+
405+ CONSOLE_CHAR_TYPE input_char;
406+ while (true ) {
407+ fflush (stdout); // Ensure all output is displayed before waiting for input
408+ input_char = CONSOLE_GET_CHAR ();
409+
410+ if (input_char == ' \r ' || input_char == ' \n ' ) {
411+ break ;
412+ }
413+
414+ if (input_char == CONSOLE_EOF || input_char == 0x04 /* Ctrl+D*/ ) {
415+ end_of_stream = true ;
416+ break ;
417+ }
418+
419+ if (is_special_char) {
420+ console_set_color (con_st, CONSOLE_COLOR_USER_INPUT);
421+ putchar (' \b ' );
422+ putchar (line.back ());
423+ is_special_char = false ;
424+ }
425+
426+ if (input_char == ' \033 ' ) { // Escape sequence
427+ CONSOLE_CHAR_TYPE code = CONSOLE_GET_CHAR ();
428+ if (code == ' [' ) {
429+ // Discard the rest of the escape sequence
430+ while ((code = CONSOLE_GET_CHAR ()) != CONSOLE_EOF) {
431+ if ((code >= ' A' && code <= ' Z' ) || (code >= ' a' && code <= ' z' ) || code == ' ~' ) {
432+ break ;
433+ }
434+ }
435+ }
436+ } else if (input_char == 0x08 || input_char == 0x7F ) { // Backspace
437+ if (!line.empty ()) {
438+ fputs (" \b \b " , stdout); // Move cursor back, print a space, and move cursor back again
439+ remove_last_utf8_char (line);
440+ }
441+ } else if (input_char < 32 ) {
442+ // Ignore control characters
443+ } else {
444+ #if defined(_WIN32)
445+ std::string utf8_char;
446+ win32_utf8_encode (std::wstring (1 , input_char), utf8_char);
447+ line += utf8_char;
448+ fputs (utf8_char.c_str (), stdout);
449+ #else
450+ line += input_char;
451+ putchar (input_char);
452+ #endif
453+ }
454+
455+ if (!line.empty () && (line.back () == ' \\ ' || line.back () == ' /' )) {
456+ console_set_color (con_st, CONSOLE_COLOR_PROMPT);
457+ putchar (' \b ' );
458+ putchar (line.back ());
459+ is_special_char = true ;
460+ }
461+ }
462+
463+ bool has_more = con_st.author_mode ;
464+ if (is_special_char) {
465+ fputs (" \b \b " , stdout); // Move cursor back, print a space, and move cursor back again
466+
467+ char last = line.back ();
468+ line.pop_back ();
469+ if (last == ' \\ ' ) {
470+ line += ' \n ' ;
471+ putchar (' \n ' );
472+ has_more = !has_more;
473+ } else {
474+ // llama doesn't seem to process a single space
475+ if (line.length () == 1 && line.back () == ' ' ) {
476+ line.clear ();
477+ putchar (' \b ' );
478+ }
479+ has_more = false ;
480+ }
481+ } else {
482+ if (end_of_stream) {
483+ has_more = false ;
484+ } else {
485+ line += ' \n ' ;
486+ putchar (' \n ' );
487+ }
488+ }
489+
490+ fflush (stdout);
491+ return has_more;
492+ }
0 commit comments