29 changes: 15 additions & 14 deletions guidance/library/_gen.py
@@ -7,6 +7,7 @@

log = logging.getLogger(__name__)


async def gen(name=None, stop=None, stop_regex=None, save_stop_text=False, max_tokens=500, n=1, stream=None,
temperature=0.0, top_p=1.0, logprobs=None, pattern=None, hidden=False, list_append=False,
save_prompt=False, token_healing=None, function_call="none", _parser_context=None, **llm_kwargs):
@@ -68,7 +69,7 @@ async def gen(name=None, stop=None, stop_regex=None, save_stop_text=False, max_t
next_next_node = _parser_context["next_next_node"]
prev_node = _parser_context["prev_node"]
# partial_output = _parser_context["partial_output"]
pos = len(variable_stack["@raw_prefix"]) # save the current position in the prefix
pos = len(variable_stack["@raw_prefix"]) # save the current position in the prefix

if hidden:
variable_stack.push({"@raw_prefix": variable_stack["@raw_prefix"]})
@@ -104,7 +105,7 @@ async def gen(name=None, stop=None, stop_regex=None, save_stop_text=False, max_t
end_tag = "</"+m.group(1)+">"
if next_text.startswith(end_tag):
stop = end_tag

# fall back to the next node's text (this was too easy to accidentally trigger, so we disable it now)
# if stop is None:
# stop = next_text
@@ -160,17 +161,17 @@ async def gen(name=None, stop=None, stop_regex=None, save_stop_text=False, max_t
variable_stack[name+"_logprobs"].append([])
assert len(variable_stack[name]) == len(variable_stack[name+"_logprobs"])
async for resp in gen_obj:
await asyncio.sleep(0) # allow other tasks to run
await asyncio.sleep(0) # allow other tasks to run
#log("parser.should_stop = " + str(parser.should_stop))
if parser.should_stop:
#log("Stopping generation")
break
# log.debug("resp", resp)
new_text = resp["choices"][0].get("text", "")
new_text = resp.choices[0].delta.content or ""
generated_value += new_text
variable_stack["@raw_prefix"] += new_text
if logprobs is not None:
logprobs_out.extend(resp["choices"][0]["logprobs"]["top_logprobs"])
logprobs_out.extend(resp.choices[0].logprobs["top_logprobs"])
if list_append:
variable_stack[name][list_ind] = generated_value
if logprobs is not None:
@@ -179,13 +180,13 @@ async def gen(name=None, stop=None, stop_regex=None, save_stop_text=False, max_t
variable_stack[name] = generated_value
if logprobs is not None:
variable_stack[name+"_logprobs"] = logprobs_out

# save the final stopping text if requested
if save_stop_text is not False:
if save_stop_text is True:
save_stop_text = name+"_stop_text"
variable_stack[save_stop_text] = resp["choices"][0].get('stop_text', None)
variable_stack[save_stop_text] = resp.choices[0].get('stop_text', None)

if hasattr(gen_obj, 'close'):
gen_obj.close()
generated_value += suffix
@@ -199,25 +200,25 @@ async def gen(name=None, stop=None, stop_regex=None, save_stop_text=False, max_t
new_content = variable_stack["@raw_prefix"][pos:]
variable_stack.pop()
variable_stack["@raw_prefix"] += "{{!--GHIDDEN:"+new_content.replace("--}}", "--_END_END")+"--}}"

# stop executing if we were interrupted
if parser.should_stop:
parser.executing = False
parser.should_stop = False
return
else:
assert not isinstance(gen_obj, list), "Streaming is only supported for n=1"
generated_values = [prefix+choice["text"]+suffix for choice in gen_obj["choices"]]
generated_values = [prefix+choice["text"]+suffix for choice in gen_obj.choices]
if list_append:
value_list = variable_stack.get(name, [])
value_list.append(generated_values)
if logprobs is not None:
logprobs_list = variable_stack.get(name+"_logprobs", [])
logprobs_list.append([choice["logprobs"]["top_logprobs"] for choice in gen_obj["choices"]])
logprobs_list.append([choice.logprobs["top_logprobs"] for choice in gen_obj.choices])
elif name is not None:
variable_stack[name] = generated_values
if logprobs is not None:
variable_stack[name+"_logprobs"] = [choice["logprobs"]["top_logprobs"] for choice in gen_obj["choices"]]
variable_stack[name+"_logprobs"] = [choice.logprobs["top_logprobs"] for choice in gen_obj.choices]

if not hidden:
# TODO: we could enable the parsing to branch into multiple paths here, but for now we just complete the program with the first prefix
@@ -228,7 +229,7 @@ async def gen(name=None, stop=None, stop_regex=None, save_stop_text=False, max_t
# we mostly support this so that the echo=False hiding behavior does not make multiple outputs more complicated than it needs to be in the UX
# if echo:
# variable_stack["@raw_prefix"] += generated_value

id = uuid.uuid4().hex
l = len(generated_values)
out = "{{!--" + f"GMARKERmany_generate_start_{not hidden}_{l}${id}$" + "--}}"
@@ -246,4 +247,4 @@ async def gen(name=None, stop=None, stop_regex=None, save_stop_text=False, max_t
# return "".join([v for v in generated_values])
else:
# pop off the variable context we pushed since we are hidden
variable_stack.pop()
variable_stack.pop()
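
The `_gen.py` changes above move from the openai 0.x dictionary responses (`resp["choices"][0]["text"]`) to the typed objects returned by the 1.x client, where streamed chat chunks expose their text under `choices[0].delta.content`. A minimal sketch of that streaming shape follows; it is not part of this PR, the model name and prompt are placeholders, and it assumes `OPENAI_API_KEY` is set in the environment.

# Sketch only: how openai>=1.0 streams chat completions.
import asyncio
from openai import AsyncOpenAI

async def stream_demo():
    client = AsyncOpenAI()  # picks up OPENAI_API_KEY from the environment
    stream = await client.chat.completions.create(
        model="gpt-3.5-turbo",  # placeholder model
        messages=[{"role": "user", "content": "Say hello."}],
        stream=True,
    )
    text = ""
    async for chunk in stream:
        # delta.content is None on role-only and final chunks, hence the
        # `or ""` guard mirrored by the new code in _gen.py
        text += chunk.choices[0].delta.content or ""
    return text

print(asyncio.run(stream_demo()))
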
130 changes: 63 additions & 67 deletions guidance/llms/_openai.py
@@ -13,7 +13,7 @@
import regex

from ._llm import LLM, LLMSession, SyncSession

from openai import RateLimitError, APIError, APITimeoutError, APIStatusError, AsyncOpenAI

class MalformedPromptException(Exception):
pass
@@ -68,32 +68,32 @@ async def add_text_to_chat_mode_generator(chat_mode):
in_function_call = False
async for resp in chat_mode:
if "choices" in resp:
for c in resp['choices']:
for c in resp.choices:

# move content from delta to text so we have a consistent interface with non-chat mode
found_content = False
if "content" in c['delta'] and c['delta']['content'] != "":
if "content" in c.delta and c.delta.content != "":
found_content = True
c['text'] = c['delta']['content']
c.text = c.delta.content

# capture function call data and convert to text again so we have a consistent interface with non-chat mode and open models
if "function_call" in c['delta']:
if "function_call" in c.delta:

# build the start of the function call (the follows the syntax that GPT says it wants when we ask it, and will be parsed by the @function_detector)
if not in_function_call:
start_val = "\n```typescript\nfunctions."+c['delta']['function_call']["name"]+"("
if not c['text']:
c['text'] = start_val
start_val = "\n```typescript\nfunctions."+c.delta.function_call.name+"("
if not c.text:
c.text = start_val
else:
c['text'] += start_val
c.text += start_val
in_function_call = True

# extend the arguments JSON string
val = c['delta']['function_call']["arguments"]
val = c.delta.function_call.arguments
if 'text' in c:
c['text'] += val
c.text += val
else:
c['text'] = val
c.text = val

if not found_content and not in_function_call:
break # the role markers are outside the generation in chat mode right now TODO: consider how this changes for uncontrained generation
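
The attribute accesses above rely on how the 1.x stream reports function calls: the pieces arrive on `choices[0].delta.function_call`, with `name` present only on the first chunk and the JSON `arguments` string spread across later chunks. A hedged sketch of reassembling such a call, assuming `stream` was created with `stream=True` and a `functions=[...]` definition as in the previous example:

# Sketch only: collect a streamed function call from an openai>=1.0 chat stream.
async def collect_function_call(stream):
    name, arguments = None, ""
    async for chunk in stream:
        delta = chunk.choices[0].delta
        if delta.function_call is not None:
            if delta.function_call.name:  # present only on the first chunk
                name = delta.function_call.name
            arguments += delta.function_call.arguments or ""
    return name, arguments  # e.g. ("get_current_weather", '{"location": "Paris"}')
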
@@ -106,13 +106,21 @@ async def add_text_to_chat_mode_generator(chat_mode):
if in_function_call:
yield {'choices': [{'text': ')```'}]}

def add_text_to_chat_mode(chat_mode):
import types
import asyncio

async def add_text_to_chat_mode(chat_mode):
if isinstance(chat_mode, (types.AsyncGeneratorType, types.GeneratorType)):
return add_text_to_chat_mode_generator(chat_mode)
elif isinstance(chat_mode, asyncio.StreamReader):
async for c in chat_mode:
c.text = c.message.content
# for c in chat_mode.choices:
# c.text = c.message.content
# return chat_mode
else:
for c in chat_mode['choices']:
c['text'] = c['message']['content']
return chat_mode
return add_text_to_chat_mode_generator(chat_mode)


class OpenAI(LLM):
llm_name: str = "openai"
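
The rewritten `add_text_to_chat_mode` above has to cope with two response shapes under openai>=1.0: a streamed call yields an async iterable of chunks, while a non-streamed call returns a `ChatCompletion` whose choices carry `.message.content` instead of a `choices[...]['message']` dict. A rough sketch of that normalization, not the PR's implementation; the `AsyncStream` import is an assumption about the 1.x package layout:

# Sketch only: one way to route streamed and plain chat responses onto the
# `.text` convention used elsewhere in this module.
from openai import AsyncStream  # assumption: exported at the package top level

async def normalize_chat_response(resp):
    if isinstance(resp, AsyncStream):
        # streamed: hand off to the chunk-level adapter defined above
        return add_text_to_chat_mode_generator(resp)
    # non-streamed: copy message.content onto a text field, mirroring the
    # c.text assignments in the code above
    for c in resp.choices:
        c.text = c.message.content
    return resp
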
@@ -254,8 +262,8 @@ async def stream_then_save(cls, gen, key, stop_regex, n):
if stop_regex is not None:

# keep track of the generated text so far
for i,choice in enumerate(curr_out['choices']):
current_strings[i] += choice['text']
for i,choice in enumerate(curr_out.choices):
current_strings[i] += choice.text

# check if all of the strings match a stop string (and hence we can stop the batch inference)
all_done = True
@@ -296,12 +304,12 @@ async def stream_then_save(cls, gen, key, stop_regex, n):
cached_out = None

if stop_regex is not None:
for i in range(len(out['choices'])):
if stop_pos[i] < len(out['choices'][i]['text']):
out['choices'][i] = out['choices'][i].to_dict() # because sometimes we might need to set the text to the empty string (and OpenAI's object does not like that)
out['choices'][i]['text'] = out['choices'][i]['text'][:stop_pos[i]]
out['choices'][i]['stop_text'] = stop_text[i]
out['choices'][i]['finish_reason'] = "stop"
for i in range(len(out.choices)):
if stop_pos[i] < len(out.choices[i].text):
out.choices[i] = out.choices[i].to_dict() # because sometimes we might need to set the text to the empty string (and OpenAI's object does not like that)
out.choices[i].text = out.choices[i].text[:stop_pos[i]]
out.choices[i].stop_text = stop_text[i]
out.choices[i].finish_reason = "stop"

list_out.append(out)
yield out
@@ -342,44 +350,31 @@ async def _library_call(self, **kwargs):
Note that it uses the local auth token, and does not rely on the openai one.
"""

# save the params of the openai library
prev_key = openai.api_key
prev_org = openai.organization
prev_type = openai.api_type
prev_version = openai.api_version
prev_base = openai.api_base

# set the params of the openai library if we have them
if self.api_key is not None:
openai.api_key = self.api_key
if self.organization is not None:
openai.organization = self.organization
if self.api_type is not None:
openai.api_type = self.api_type
if self.api_version is not None:
openai.api_version = self.api_version
if self.api_base is not None:
openai.api_base = self.api_base

assert openai.api_key is not None, "You must provide an OpenAI API key to use the OpenAI LLM. Either pass it in the constructor, set the OPENAI_API_KEY environment variable, or create the file ~/.openai_api_key with your key in it."

api_key = self.api_key or openai.api_key
if api_key is None:
raise Exception("You must provide an OpenAI API key to use the OpenAI LLM. Either pass it in the constructor, set the OPENAI_API_KEY environment variable, or create the file ~/.openai_api_key with your key in it.")

# Instantiate the AsyncOpenAI client
client = AsyncOpenAI(
api_key=self.api_key or openai.api_key,
organization=self.organization or openai.organization,

# TODO: Implement AzureOpenAI, which implements the Azure client
# api_type=self.api_type or openai.api_type,
# api_version=self.api_version or openai.api_version,
base_url=self.api_base
)

if self.chat_mode:
kwargs['messages'] = prompt_to_messages(kwargs['prompt'])
del kwargs['prompt']
del kwargs['echo']
del kwargs['logprobs']
# print(kwargs)
out = await openai.ChatCompletion.acreate(**kwargs)
out = add_text_to_chat_mode(out)
out = await client.chat.completions.create(**kwargs)
out = await add_text_to_chat_mode(out)
else:
out = await openai.Completion.acreate(**kwargs)

# restore the params of the openai library
openai.api_key = prev_key
openai.organization = prev_org
openai.api_type = prev_type
openai.api_version = prev_version
openai.api_base = prev_base
out = await client.completions.create(**kwargs)

return out
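
The `_library_call` rewrite above drops the 0.x pattern of saving, mutating and restoring module-level globals (`openai.api_key`, `openai.api_base`, ...) in favour of a per-call `AsyncOpenAI` client, and swaps `openai.ChatCompletion.acreate` / `openai.Completion.acreate` for `client.chat.completions.create` / `client.completions.create`. A condensed sketch of that shape; the model names and prompt are placeholders:

# Sketch only: the openai>=1.0 client replaces the 0.x module-level configuration.
from openai import AsyncOpenAI

async def call_openai(prompt, chat_mode=True):
    client = AsyncOpenAI(
        api_key=None,        # placeholder; None falls back to OPENAI_API_KEY
        organization=None,   # optional
        base_url=None,       # optional; None means the default api.openai.com
    )
    if chat_mode:
        return await client.chat.completions.create(
            model="gpt-3.5-turbo",  # placeholder model
            messages=[{"role": "user", "content": prompt}],
        )
    return await client.completions.create(
        model="gpt-3.5-turbo-instruct",  # placeholder model
        prompt=prompt,
    )

Because configuration lives on the client instance, the save-and-restore dance around the old globals (deleted above) is no longer needed.
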

@@ -410,7 +405,7 @@ async def _rest_call(self, **kwargs):
data['messages'] = prompt_to_messages(data['prompt'])
del data['prompt']
del data['echo']
del data['logprobs']

# Send a POST request and get the response
# An exception for timeout is raised if the server has not issued a response for 10 seconds
@@ -470,18 +465,18 @@ def merge_stream_chunks(first_chunk, second_chunk):
out = copy.deepcopy(first_chunk)

# merge the choices
for i in range(len(out['choices'])):
out_choice = out['choices'][i]
second_choice = second_chunk['choices'][i]
out_choice['text'] += second_choice['text']
for i in range(len(out.choices)):
out_choice = out.choices[i]
second_choice = second_chunk.choices[i]
out_choice.text += second_choice.text
if 'index' in second_choice:
out_choice['index'] = second_choice['index']
out_choice.index = second_choice.index
if 'finish_reason' in second_choice:
out_choice['finish_reason'] = second_choice['finish_reason']
out_choice.finish_reason = second_choice.finish_reason
if out_choice.get('logprobs', None) is not None:
out_choice['logprobs']['token_logprobs'] += second_choice['logprobs']['token_logprobs']
out_choice['logprobs']['top_logprobs'] += second_choice['logprobs']['top_logprobs']
out_choice['logprobs']['text_offset'] = second_choice['logprobs']['text_offset']
out_choice.logprobs.token_logprobs += second_choice.logprobs.token_logprobs
out_choice.logprobs.top_logprobs += second_choice.logprobs.top_logprobs
out_choice.logprobs.text_offset = second_choice.logprobs.text_offset

return out

@@ -645,7 +640,8 @@ async def __call__(self, prompt, stop=None, stop_regex=None, temperature=None, n
self.llm.add_call()
call_args = {
"model": self.llm.model_name,
"deployment_id": self.llm.deployment_id,
# TODO: Move deployment_id to AzureOpenAI LLM implementation
# "deployment_id": self.llm.deployment_id,
"prompt": prompt,
"max_tokens": max_tokens,
"temperature": temperature,
Expand All @@ -666,7 +662,7 @@ async def __call__(self, prompt, stop=None, stop_regex=None, temperature=None, n
call_args["logit_bias"] = {str(k): v for k,v in logit_bias.items()} # convert keys to strings since that's the open ai api's format
out = await self.llm.caller(**call_args)

except (openai.error.RateLimitError, openai.error.ServiceUnavailableError, openai.error.APIError, openai.error.Timeout):
except (RateLimitError, APIError, APITimeoutError, APIStatusError) as e:
await asyncio.sleep(3)
try_again = True
fail_count += 1
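
openai>=1.0 also renames the retryable exceptions: `openai.error.RateLimitError` and friends become the top-level `RateLimitError`, `APIError`, `APITimeoutError` and `APIStatusError` caught by the new `except` clause above. A hedged sketch of the same retry idea in isolation; the retry count and delay are illustrative:

# Sketch only: retry an openai>=1.0 call on transient errors.
import asyncio
from openai import APIError, APIStatusError, APITimeoutError, RateLimitError

async def call_with_retries(make_call, max_retries=3, delay=3):
    for attempt in range(max_retries + 1):
        try:
            return await make_call()
        except (RateLimitError, APIError, APITimeoutError, APIStatusError):
            if attempt == max_retries:
                raise
            await asyncio.sleep(delay)  # fixed backoff, mirroring the code above
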
2 changes: 1 addition & 1 deletion setup.py
@@ -31,7 +31,7 @@ def find_version(*file_paths):
install_requires=[
"diskcache",
"gptcache",
"openai>=0.27.8",
"openai>=1.0.0",
"pyparsing>=3.0.0",
"pygtrie",
"platformdirs",
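
The `setup.py` bump from `openai>=0.27.8` to `openai>=1.0.0` is what makes the `AsyncOpenAI` client and attribute-style responses used above available; the 0.x and 1.x Python APIs are not compatible. An optional, illustrative sketch for failing fast when an old install is still present:

# Sketch only: guard against an openai 0.x installation at import time.
from importlib.metadata import version

if int(version("openai").split(".")[0]) < 1:
    raise ImportError("openai>=1.0.0 is required; upgrade with: pip install -U openai")
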