Skip to content

Commit bc57994

Browse files
committed
Pin tensorforce and add test
1 parent e50da21 commit bc57994

File tree

2 files changed

+32
-1
lines changed

2 files changed

+32
-1
lines changed

Dockerfile

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -371,7 +371,8 @@ RUN pip install flashtext && \
371371
pip install shap && \
372372
pip install ray && \
373373
pip install gym && \
374-
pip install tensorforce && \
374+
# b/167268016 tensorforce 0.6.6 has an explicit dependency on tensorflow 2.3.1 which is causing a downgrade.
375+
pip install tensorforce==0.5.5 && \
375376
pip install pyarabic && \
376377
pip install pandasql && \
377378
pip install tensorflow_hub && \

tests/test_tensorforce.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import unittest
2+
3+
from tensorforce import Agent, Environment
4+
5+
class TestTensorforce(unittest.TestCase):
6+
# based on https://github.com/tensorforce/tensorforce/tree/master#quickstart-example-code.
7+
def test_quickstart(self):
8+
environment = Environment.create(
9+
environment='gym', level='CartPole', max_episode_timesteps=500
10+
)
11+
12+
agent = Agent.create(
13+
agent='tensorforce',
14+
environment=environment, # alternatively: states, actions, (max_episode_timesteps)
15+
memory=1000,
16+
update=dict(unit='timesteps', batch_size=32),
17+
optimizer=dict(type='adam', learning_rate=3e-4),
18+
policy=dict(network='auto'),
19+
objective='policy_gradient',
20+
reward_estimation=dict(horizon=1)
21+
)
22+
23+
# Train for a single episode.
24+
states = environment.reset()
25+
actions = agent.act(states=states)
26+
states, terminal, reward = environment.execute(actions=actions)
27+
28+
self.assertEqual(4, len(states))
29+
self.assertFalse(terminal)
30+
self.assertEqual(1, reward)

0 commit comments

Comments
 (0)