Skip to content

Commit 1a7016f

Browse files
Author: Chris Cummins
Commit message: Rebase fixes for service pool PR.
1 parent 2210f50 commit 1a7016f

File tree

2 files changed: 33 additions and 250 deletions.

compiler_gym/envs/compiler_env.py

Lines changed: 0 additions & 238 deletions
Original file line number | Diff line number | Diff line change
@@ -65,244 +65,6 @@ def observation_space_spec(self) -> ObservationSpaceSpec:
6565
def observation_space_spec(
6666
self, observation_space_spec: Optional[ObservationSpaceSpec]
6767
):
68-
<<<<<<< HEAD
69-
=======
70-
"""Construct and initialize a CompilerGym environment.
71-
72-
In normal use you should use :code:`gym.make(...)` rather than calling
73-
the constructor directly.
74-
75-
:param service: The hostname and port of a service that implements the
76-
CompilerGym service interface, or the path of a binary file which
77-
provides the CompilerGym service interface when executed. See
78-
:doc:`/compiler_gym/service` for details.
79-
80-
:param rewards: The reward spaces that this environment supports.
81-
Rewards are typically calculated based on observations generated by
82-
the service. See :class:`Reward <compiler_gym.spaces.Reward>` for
83-
details.
84-
85-
:param benchmark: The benchmark to use for this environment. Either a
86-
URI string, or a :class:`Benchmark
87-
<compiler_gym.datasets.Benchmark>` instance. If not provided, the
88-
first benchmark as returned by
89-
:code:`next(env.datasets.benchmarks())` will be used as the default.
90-
91-
:param observation_space: Compute and return observations at each
92-
:func:`step()` from this space. Accepts a string name or an
93-
:class:`ObservationSpaceSpec
94-
<compiler_gym.views.ObservationSpaceSpec>`. If not provided,
95-
:func:`step()` returns :code:`None` for the observation value. Can
96-
be set later using :meth:`env.observation_space
97-
<compiler_gym.envs.CompilerEnv.observation_space>`. For available
98-
spaces, see :class:`env.observation.spaces
99-
<compiler_gym.views.ObservationView>`.
100-
101-
:param reward_space: Compute and return reward at each :func:`step()`
102-
from this space. Accepts a string name or a :class:`Reward
103-
<compiler_gym.spaces.Reward>`. If not provided, :func:`step()`
104-
returns :code:`None` for the reward value. Can be set later using
105-
:meth:`env.reward_space
106-
<compiler_gym.envs.CompilerEnv.reward_space>`. For available spaces,
107-
see :class:`env.reward.spaces <compiler_gym.views.RewardView>`.
108-
109-
:param action_space: The name of the action space to use. If not
110-
specified, the default action space for this compiler is used.
111-
112-
:param derived_observation_spaces: An optional list of arguments to be
113-
passed to :meth:`env.observation.add_derived_space()
114-
<compiler_gym.views.observation.Observation.add_derived_space>`.
115-
116-
:param connection_settings: The settings used to establish a connection
117-
with the remote service.
118-
119-
:param service_connection: An existing compiler gym service connection
120-
to use.
121-
122-
:param service_pool: A service pool to use for acquiring a service
123-
connection. If not specified, the :meth:`global service pool
124-
<compiler_gym.service.ServiceConnectionPool.get>` is used.
125-
126-
:raises FileNotFoundError: If service is a path to a file that is not
127-
found.
128-
129-
:raises TimeoutError: If the compiler service fails to initialize within
130-
the parameters provided in :code:`connection_settings`.
131-
"""
132-
# NOTE(cummins): Logger argument deprecated and scheduled to be removed
133-
# in release 0.2.3.
134-
if logger:
135-
warnings.warn(
136-
"The `logger` argument is deprecated on CompilerEnv.__init__() "
137-
"and will be removed in a future release. All CompilerEnv "
138-
"instances share a logger named compiler_gym.envs.compiler_env",
139-
DeprecationWarning,
140-
)
141-
142-
self.metadata = {"render.modes": ["human", "ansi"]}
143-
144-
# A compiler service supports multiple simultaneous environments. This
145-
# session ID is used to identify this environment.
146-
self._session_id: Optional[int] = None
147-
148-
self._service_endpoint: Union[str, Path] = service
149-
self._connection_settings = connection_settings or ConnectionOpts()
150-
151-
if service_connection is None:
152-
self._service_pool: Optional[ServiceConnectionPool] = (
153-
ServiceConnectionPool.get() if service_pool is None else service_pool
154-
)
155-
self.service = self._service_pool.acquire(
156-
endpoint=self._service_endpoint,
157-
opts=self._connection_settings,
158-
)
159-
else:
160-
self._service_pool: Optional[ServiceConnectionPool] = service_pool
161-
self.service = service_connection
162-
163-
self.datasets = Datasets(datasets or [])
164-
165-
self.action_space_name = action_space
166-
167-
# If no reward space is specified, generate some from numeric observation spaces
168-
rewards = rewards or [
169-
DefaultRewardFromObservation(obs.name)
170-
for obs in self.service.observation_spaces
171-
if obs.default_observation.WhichOneof("value")
172-
and isinstance(
173-
getattr(
174-
obs.default_observation, obs.default_observation.WhichOneof("value")
175-
),
176-
numbers.Number,
177-
)
178-
]
179-
180-
# The benchmark that is currently being used, and the benchmark that
181-
# will be used on the next call to reset(). These are equal except in
182-
# the gap between the user setting the env.benchmark property while in
183-
# an episode and the next call to env.reset().
184-
self._benchmark_in_use: Optional[Benchmark] = None
185-
self._benchmark_in_use_proto: BenchmarkProto = BenchmarkProto()
186-
self._next_benchmark: Optional[Benchmark] = None
187-
# Normally when the benchmark is changed the updated value is not
188-
# reflected until the next call to reset(). We make an exception for the
189-
# constructor-time benchmark as otherwise the behavior of the benchmark
190-
# property is counter-intuitive:
191-
#
192-
# >>> env = gym.make("example-v0", benchmark="foo")
193-
# >>> env.benchmark
194-
# None
195-
# >>> env.reset()
196-
# >>> env.benchmark
197-
# "foo"
198-
#
199-
# By forcing the _benchmark_in_use URI at constructor time, the first
200-
# env.benchmark above returns the benchmark as expected.
201-
try:
202-
self.benchmark = benchmark or next(self.datasets.benchmarks())
203-
self._benchmark_in_use = self._next_benchmark
204-
except StopIteration:
205-
# StopIteration raised on next(self.datasets.benchmarks()) if there
206-
# are no benchmarks available. This is to allow CompilerEnv to be
207-
# used without any datasets by setting a benchmark before/during the
208-
# first reset() call.
209-
pass
210-
211-
# Process the available action, observation, and reward spaces.
212-
self.action_spaces = [
213-
proto_to_action_space(space) for space in self.service.action_spaces
214-
]
215-
216-
self.observation = self._observation_view_type(
217-
raw_step=self.raw_step,
218-
spaces=self.service.observation_spaces,
219-
)
220-
self.reward = self._reward_view_type(rewards, self.observation)
221-
222-
# Register any derived observation spaces now so that the observation
223-
# space can be set below.
224-
for derived_observation_space in derived_observation_spaces or []:
225-
self.observation.add_derived_space_internal(**derived_observation_space)
226-
227-
# Lazily evaluated version strings.
228-
self._versions: Optional[GetVersionReply] = None
229-
230-
self.action_space: Optional[Space] = None
231-
self.observation_space: Optional[Space] = None
232-
233-
# Mutable state initialized in reset().
234-
self.reward_range: Tuple[float, float] = (-np.inf, np.inf)
235-
self.episode_reward: Optional[float] = None
236-
self.episode_start_time: float = time()
237-
self.actions: List[ActionType] = []
238-
239-
# Initialize the default observation/reward spaces.
240-
self.observation_space_spec: Optional[ObservationSpaceSpec] = None
241-
self.reward_space_spec: Optional[Reward] = None
242-
self.observation_space = observation_space
243-
self.reward_space = reward_space
244-
245-
@property
246-
@deprecated(
247-
version="0.2.1",
248-
reason=(
249-
"The `CompilerEnv.logger` attribute is deprecated. All CompilerEnv "
250-
"instances share a logger named compiler_gym.envs.compiler_env"
251-
),
252-
)
253-
def logger(self):
254-
return _logger
255-
256-
@property
257-
def versions(self) -> GetVersionReply:
258-
"""Get the version numbers from the compiler service."""
259-
if self._versions is None:
260-
self._versions = self.service(
261-
self.service.stub.GetVersion, GetVersionRequest()
262-
)
263-
return self._versions
264-
265-
@property
266-
def version(self) -> str:
267-
"""The version string of the compiler service."""
268-
return self.versions.service_version
269-
270-
@property
271-
def compiler_version(self) -> str:
272-
"""The version string of the underlying compiler that this service supports."""
273-
return self.versions.compiler_version
274-
275-
def commandline(self) -> str:
276-
"""Interface for :class:`CompilerEnv <compiler_gym.envs.CompilerEnv>`
277-
subclasses to provide an equivalent commandline invocation to the
278-
current environment state.
279-
280-
See also :meth:`commandline_to_actions()
281-
<compiler_gym.envs.CompilerEnv.commandline_to_actions>`.
282-
283-
Calling this method on a :class:`CompilerEnv
284-
<compiler_gym.envs.CompilerEnv>` instance raises
285-
:code:`NotImplementedError`.
286-
287-
:return: A string commandline invocation.
288-
"""
289-
raise NotImplementedError("abstract method")
290-
291-
def commandline_to_actions(self, commandline: str) -> List[ActionType]:
292-
"""Interface for :class:`CompilerEnv <compiler_gym.envs.CompilerEnv>`
293-
subclasses to convert from a commandline invocation to a sequence of
294-
actions.
295-
296-
See also :meth:`commandline()
297-
<compiler_gym.envs.CompilerEnv.commandline>`.
298-
299-
Calling this method on a :class:`CompilerEnv
300-
<compiler_gym.envs.CompilerEnv>` instance raises
301-
:code:`NotImplementedError`.
302-
303-
:return: A list of actions.
304-
"""
305-
>>>>>>> 4a874cee (Fix typo in docstring.)
30668
raise NotImplementedError("abstract method")
30769

30870
@property

compiler_gym/service/client_service_compiler_env.py

Lines changed: 33 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
from time import time
1515
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
1616

17-
from compiler_gym.service.connection_pool import ServiceConnectionPool
1817
import numpy as np
1918
from deprecated.sphinx import deprecated
2019
from gym.spaces import Space
@@ -32,6 +31,10 @@
3231
SessionNotFound,
3332
)
3433
from compiler_gym.service.connection import ServiceIsClosed
34+
from compiler_gym.service.connection_pool import (
35+
ServiceConnectionPool,
36+
ServiceConnectionPoolBase,
37+
)
3538
from compiler_gym.service.proto import ActionSpace, AddBenchmarkRequest
3639
from compiler_gym.service.proto import Benchmark as BenchmarkProto
3740
from compiler_gym.service.proto import (
@@ -136,6 +139,7 @@ def __init__(
136139
reward_space: Optional[Union[str, Reward]] = None,
137140
action_space: Optional[str] = None,
138141
derived_observation_spaces: Optional[List[Dict[str, Any]]] = None,
142+
service_message_converters: ServiceMessageConverters = None,
139143
connection_settings: Optional[ConnectionOpts] = None,
140144
service_connection: Optional[CompilerGymServiceConnection] = None,
141145
service_pool: Optional[ServiceConnectionPool] = None,
@@ -187,6 +191,9 @@ def __init__(
187191
passed to :meth:`env.observation.add_derived_space()
188192
<compiler_gym.views.observation.Observation.add_derived_space>`.
189193
194+
:param service_message_converters: Custom converters for action spaces
195+
and actions.
196+
190197
:param connection_settings: The settings used to establish a connection
191198
with the remote service.
192199
@@ -234,7 +241,7 @@ def __init__(
234241
self._service_pool: Optional[ServiceConnectionPoolBase] = service_pool
235242
self.service = service_connection
236243

237-
self.datasets = Datasets(datasets or [])
244+
self._datasets = Datasets(datasets or [])
238245

239246
self.action_space_name = action_space
240247

@@ -282,9 +289,16 @@ def __init__(
282289
# first reset() call.
283290
pass
284291

292+
self.service_message_converters = (
293+
ServiceMessageConverters()
294+
if service_message_converters is None
295+
else service_message_converters
296+
)
297+
285298
# Process the available action, observation, and reward spaces.
286299
self.action_spaces = [
287-
proto_to_action_space(space) for space in self.service.action_spaces
300+
self.service_message_converters.action_space_converter(space)
301+
for space in self.service.action_spaces
288302
]
289303

290304
self.observation = self._observation_view_type(
@@ -308,7 +322,7 @@ def __init__(
308322
self.reward_range: Tuple[float, float] = (-np.inf, np.inf)
309323
self.episode_reward: Optional[float] = None
310324
self.episode_start_time: float = time()
311-
self.actions: List[ActionType] = []
325+
self._actions: List[ActionType] = []
312326

313327
# Initialize the default observation/reward spaces.
314328
self.observation_space_spec: Optional[ObservationSpaceSpec] = None
@@ -544,10 +558,11 @@ def _init_kwargs(self) -> Dict[str, Any]:
544558
"benchmark": self.benchmark,
545559
"connection_settings": self._connection_settings,
546560
"service": self._service_endpoint,
561+
"service_pool": self._service_pool,
547562
}
548563

549564
def fork(self) -> "ClientServiceCompilerEnv":
550-
if not self.in_episode:
565+
if not self.in_episode:
551566
actions = self.actions.copy()
552567
self.reset()
553568
if actions:
@@ -603,7 +618,7 @@ def fork(self) -> "ClientServiceCompilerEnv":
603618
# Copy over the mutable episode state.
604619
new_env.episode_reward = self.episode_reward
605620
new_env.episode_start_time = self.episode_start_time
606-
new_env.actions = self.actions.copy()
621+
new_env._actions = self.actions.copy() # pylint: disable=protected-access
607622

608623
return new_env
609624

@@ -698,7 +713,7 @@ def reset( # pylint: disable=arguments-differ
698713
does not have a default benchmark to select from.
699714
"""
700715

701-
def _retry(error) -> Optional[ObservationType]:
716+
def _retry(error) -> Optional[ObservationType]:
702717
"""Abort and retry on error."""
703718
# Log the error that we are recovering from, but treat
704719
# ServiceIsClosed errors as unimportant since we know what causes
@@ -837,11 +852,13 @@ def _call_with_error(
837852
self.observation.session_id = reply.session_id
838853
self.reward.get_cost = self.observation.__getitem__
839854
self.episode_start_time = time()
840-
self.actions = []
855+
self._actions: List[ActionType] = []
841856

842857
# If the action space has changed, update it.
843858
if reply.HasField("new_action_space"):
844-
self.action_space = proto_to_action_space(reply.new_action_space)
859+
self.action_space = self.service_message_converters.action_space_converter(
860+
reply.new_action_space
861+
)
845862

846863
self.reward.reset(benchmark=self.benchmark, observation_view=self.observation)
847864
if self.reward_space:
@@ -905,12 +922,14 @@ def raw_step(
905922
}
906923

907924
# Record the actions.
908-
self.actions += actions
925+
self._actions += actions
909926

910927
# Send the request to the backend service.
911928
request = StepRequest(
912929
session_id=self._session_id,
913-
action=[Event(int64_value=a) for a in actions],
930+
action=[
931+
self.service_message_converters.action_converter(a) for a in actions
932+
],
914933
observation_space=[
915934
observation_space.index for observation_space in observations_to_compute
916935
],
@@ -954,7 +973,9 @@ def raw_step(
954973

955974
# If the action space has changed, update it.
956975
if reply.HasField("new_action_space"):
957-
self.action_space = proto_to_action_space(reply.new_action_space)
976+
self.action_space = self.service_message_converters.action_space_converter(
977+
reply.new_action_space
978+
)
958979

959980
# Translate observations to python representations.
960981
if len(reply.observation) != len(observations_to_compute):

0 commit comments

Comments
 (0)