diff --git a/sc2/bot_ai.py b/sc2/bot_ai.py index c5f6cc205..b5dedb43d 100644 --- a/sc2/bot_ai.py +++ b/sc2/bot_ai.py @@ -133,7 +133,7 @@ def main_base_ramp(self) -> "Ramp": ParaSite map has 5 upper points, and most other maps have 2 upper points at the main ramp. The map Acolyte has 4 upper points at the wrong ramp (which is closest to the start position) """ self.cached_main_base_ramp = min( (ramp for ramp in self.game_info.map_ramps if len(ramp.upper) in {2, 5}), - key=lambda r: self.start_location._distance_squared(r.top_center), + key=lambda r: self.start_location.distance_to(r.top_center), ) return self.cached_main_base_ramp @@ -148,7 +148,7 @@ def expansion_locations(self) -> Dict[Point2, Units]: # any resource in a group is closer than 6 to any resource of another group # Distance we group resources by - RESOURCE_SPREAD_THRESHOLD = 36 + RESOURCE_SPREAD_THRESHOLD = 8.5 minerals = self.state.mineral_field geysers = self.state.vespene_geyser all_resources = minerals | geysers @@ -164,7 +164,7 @@ def expansion_locations(self) -> Dict[Point2, Units]: for group_a, group_b in itertools.combinations(resource_groups, 2): # Check if any pair of resource of these groups is closer than threshold together if any( - resource_a.position._distance_squared(resource_b.position) <= RESOURCE_SPREAD_THRESHOLD + resource_a.distance_to(resource_b) <= RESOURCE_SPREAD_THRESHOLD for resource_a, resource_b in itertools.product(group_a, group_b) ): # Remove the single groups and add the merged group @@ -179,7 +179,7 @@ def expansion_locations(self) -> Dict[Point2, Units]: (x, y) for x in range(-offset_range, offset_range + 1) for y in range(-offset_range, offset_range + 1) - if 49 >= x ** 2 + y ** 2 >= 16 + if 4 <= math.hypot(x, y) <= 7 ] # Dict we want to return centers = {} @@ -199,17 +199,11 @@ def expansion_locations(self) -> Dict[Point2, Units]: # Check if point can be built on if self._game_info.placement_grid[point.rounded] != 0 # Check if all resources have enough space to point - and all( - point._distance_squared(resource.position) >= (49 if resource in geysers else 36) - for resource in resources - ) + and all(point.distance_to(resource) >= (7 if resource in geysers else 6) for resource in resources) ) # Choose best fitting point # TODO can we improve this by calculating the distance only one time? - result = min( - possible_points, - key=lambda point: sum(point._distance_squared(resource.position) for resource in resources), - ) + result = min(possible_points, key=lambda point: sum(point.distance_to(resource) for resource in resources)) centers[result] = resources return centers @@ -271,7 +265,7 @@ async def get_next_expansion(self) -> Optional[Point2]: for el in self.expansion_locations: def is_near_to_expansion(t): - return t.position._distance_squared(el) < self.EXPANSION_GAP_THRESHOLD ** 2 + return t.distance_to(el) < self.EXPANSION_GAP_THRESHOLD if any(map(is_near_to_expansion, self.townhalls)): # already taken @@ -288,16 +282,27 @@ def is_near_to_expansion(t): return closest - async def distribute_workers(self): + async def distribute_workers(self, resource_ratio: float = 2): """ Distributes workers across all the bases taken. + Keyword `resource_ratio` takes a float. If the current minerals to gas + ratio is bigger than `resource_ratio`, this function prefer filling geysers + first, if it is lower, it will prefer sending workers to minerals first. + This is only for workers that need to be moved anyways, it will NOT will + geysers on its own. + + NOTE: This function is far from optimal, if you really want to have + refined worker control, you should write your own distribution function. + For example long distance mining control and moving workers if a base was killed + are not being handled. + WARNING: This is quite slow when there are lots of workers or multiple bases. """ - if not self.state.mineral_field or not self.workers or not self.owned_expansions.ready: + if not self.state.mineral_field or not self.workers or not self.townhalls.ready: return actions = [] worker_pool = [worker for worker in self.workers.idle] - bases = self.owned_expansions.ready + bases = self.townhalls.ready geysers = self.geysers.ready # list of places that need more workers @@ -308,7 +313,7 @@ async def distribute_workers(self): # perfect amount of workers, skip mining place if not difference: continue - if mining_place.has_vespene: + if mining_place.is_vespene_geyser: # get all workers that target the gas extraction site # or are on their way back from it local_workers = self.workers.filter( @@ -316,45 +321,71 @@ async def distribute_workers(self): or (unit.is_carrying_vespene and unit.order_target == bases.closest_to(mining_place).tag) ) else: - # get minerals around expansion - local_minerals = self.expansion_locations[mining_place.position].filter( - lambda resource: resource.has_minerals - ) + # get tags of minerals around expansion + local_minerals_tags = { + mineral.tag for mineral in self.state.mineral_field if mineral.distance_to(mining_place) <= 8 + } # get all target tags a worker can have # tags of the minerals he could mine at that base # get workers that work at that gather site local_workers = self.workers.filter( - lambda unit: unit.order_target in local_minerals.tags + lambda unit: unit.order_target in local_minerals_tags or (unit.is_carrying_minerals and unit.order_target == mining_place.tag) ) # too many workers if difference > 0: - worker_pool.append(local_workers[:difference]) + for worker in local_workers[:difference]: + worker_pool.append(worker) # too few workers # add mining place to deficit bases for every missing worker else: deficit_mining_places += [mining_place for _ in range(-difference)] + # prepare all minerals near a base if we have too many workers + # and need to send them to the closest patch + if len(worker_pool) > len(deficit_mining_places): + all_minerals_near_base = [ + mineral + for mineral in self.state.mineral_field + if any(mineral.distance_to(base) <= 8 for base in self.townhalls.ready) + ] # distribute every worker in the pool for worker in worker_pool: # as long as have workers and mining places if deficit_mining_places: - # remove current place from the list for next loop - current_place = deficit_mining_places.pop(0) + # choose only mineral fields first if current mineral to gas ratio is less than target ratio + if self.vespene and self.minerals / self.vespene < resource_ratio: + possible_mining_places = [place for place in deficit_mining_places if not place.vespene_contents] + # else prefer gas + else: + possible_mining_places = [place for place in deficit_mining_places if place.vespene_contents] + # if preferred type is not available any more, get all other places + if not possible_mining_places: + possible_mining_places = deficit_mining_places + # find closest mining place + current_place = min(deficit_mining_places, key=lambda place: place.distance_to(worker)) + # remove it from the list + deficit_mining_places.remove(current_place) # if current place is a gas extraction site, go there - if current_place.has_vespene: + if current_place.vespene_contents: actions.append(worker.gather(current_place)) # if current place is a gas extraction site, # go to the mineral field that is near and has the most minerals left else: - local_minerals = self.expansion_locations[current_place.position].filter( - lambda resource: resource.has_minerals - ) + local_minerals = [ + mineral for mineral in self.state.mineral_field if mineral.distance_to(current_place) <= 8 + ] target_mineral = max(local_minerals, key=lambda mineral: mineral.mineral_contents) actions.append(worker.gather(target_mineral)) # more workers to distribute than free mining spots - # else: - # pass + # send to closest if worker is doing nothing + elif worker.is_idle and all_minerals_near_base: + target_mineral = min(all_minerals_near_base, key=lambda mineral: mineral.distance_to(worker)) + actions.append(worker.gather(target_mineral)) + else: + # there are no deficit mining places and worker is not idle + # so dont move him + pass await self.do_actions(actions) @@ -366,7 +397,7 @@ def owned_expansions(self) -> Dict[Point2, Unit]: for el in self.expansion_locations: def is_near_to_expansion(t): - return t.position._distance_squared(el) < self.EXPANSION_GAP_THRESHOLD ** 2 + return t.distance_to(el) < self.EXPANSION_GAP_THRESHOLD th = next((x for x in self.townhalls if is_near_to_expansion(x)), None) if th: @@ -425,21 +456,21 @@ async def can_cast( ability_target == 1 or ability_target == Target.PointOrNone.value and isinstance(target, (Point2, Point3)) - and unit.position._distance_squared(target.position) <= cast_range ** 2 + and unit.distance_to(target) <= cast_range ): # cant replace 1 with "Target.None.value" because ".None" doesnt seem to be a valid enum name return True # Check if able to use ability on a unit elif ( ability_target in {Target.Unit.value, Target.PointOrUnit.value} and isinstance(target, Unit) - and unit.position._distance_squared(target.position) <= cast_range ** 2 + and unit.distance_to(target) <= cast_range ): return True # Check if able to use ability on a position elif ( ability_target in {Target.Point.value, Target.PointOrUnit.value} and isinstance(target, (Point2, Point3)) - and unit.position._distance_squared(target) <= cast_range ** 2 + and unit.distance_to(target) <= cast_range ): return True return False @@ -514,7 +545,7 @@ async def find_placement( if random_alternative: return random.choice(possible) else: - return min(possible, key=lambda p: p._distance_squared(near)) + return min(possible, key=lambda p: p.distance_to_point2(near)) return None def already_pending_upgrade(self, upgrade_type: UpgradeId) -> Union[int, float]: diff --git a/sc2/cache.py b/sc2/cache.py index 195c3407b..1fa9ee49e 100644 --- a/sc2/cache.py +++ b/sc2/cache.py @@ -20,7 +20,7 @@ def property_cache_once_per_frame(f): then clears it if it is accessed in a different game loop. Only works on properties of the bot object, because it requires access to self.state.game_loop """ - + @wraps(f) def inner(self): property_cache = "_cache_" + f.__name__ diff --git a/sc2/game_info.py b/sc2/game_info.py index 36df54524..c23a16ee9 100644 --- a/sc2/game_info.py +++ b/sc2/game_info.py @@ -57,7 +57,7 @@ def upper2_for_ramp_wall(self) -> Set[Point2]: return set() # HACK: makes this work for now # FIXME: please do - upper2 = sorted(list(self.upper), key=lambda x: x._distance_squared(self.bottom_center), reverse=True) + upper2 = sorted(list(self.upper), key=lambda x: x.distance_to_point2(self.bottom_center), reverse=True) while len(upper2) > 2: upper2.pop() return set(upper2) @@ -99,7 +99,7 @@ def barracks_in_middle(self) -> Point2: # Offset from top point to barracks center is (2, 1) intersects = p1.circle_intersection(p2, 5 ** 0.5) anyLowerPoint = next(iter(self.lower)) - return max(intersects, key=lambda p: p._distance_squared(anyLowerPoint)) + return max(intersects, key=lambda p: p.distance_to_point2(anyLowerPoint)) raise Exception("Not implemented. Trying to access a ramp that has a wrong amount of upper points.") @property_immutable_cache @@ -112,7 +112,7 @@ def depot_in_middle(self) -> Point2: # Offset from top point to depot center is (1.5, 0.5) intersects = p1.circle_intersection(p2, 2.5 ** 0.5) anyLowerPoint = next(iter(self.lower)) - return max(intersects, key=lambda p: p._distance_squared(anyLowerPoint)) + return max(intersects, key=lambda p: p.distance_to_point2(anyLowerPoint)) raise Exception("Not implemented. Trying to access a ramp that has a wrong amount of upper points.") @property_mutable_cache @@ -122,7 +122,7 @@ def corner_depots(self) -> Set[Point2]: points = self.upper2_for_ramp_wall p1 = points.pop().offset((self.x_offset, self.y_offset)) # still an error with pixelmap? p2 = points.pop().offset((self.x_offset, self.y_offset)) - center = p1.towards(p2, p1.distance_to(p2) / 2) + center = p1.towards(p2, p1.distance_to_point2(p2) / 2) depotPosition = self.depot_in_middle # Offset from middle depot to corner depots is (2, 1) intersects = center.circle_intersection(depotPosition, 5 ** 0.5) @@ -224,7 +224,7 @@ def paint(pt: Point2) -> None: if picture[py][px] != NOT_COLORED_YET: continue point: Point2 = Point2((px, py)) - remaining.remove(point) + remaining.discard(point) paint(point) queue.append(point) currentGroup.add(point) diff --git a/sc2/game_state.py b/sc2/game_state.py index fac7b4999..d9d293de7 100644 --- a/sc2/game_state.py +++ b/sc2/game_state.py @@ -8,6 +8,7 @@ from .position import Point2, Point3 from .power_source import PsionicMatrix from .score import ScoreDetails +from .unit import Unit from .units import Units @@ -124,53 +125,53 @@ def __init__(self, response_observation): # https://github.com/Blizzard/s2client-proto/blob/33f0ecf615aa06ca845ffe4739ef3133f37265a9/s2clientprotocol/score.proto#L31 self.score: ScoreDetails = ScoreDetails(self.observation.score) self.abilities = self.observation.abilities # abilities of selected units - # Fix for enemy units detected by my sensor tower, as blips have less unit information than normal visible units - blipUnits, minerals, geysers, destructables, enemy, own, watchtowers = ([] for _ in range(7)) + + self._blipUnits = [] + self.own_units: Units = Units([]) + self.enemy_units: Units = Units([]) + self.mineral_field: Units = Units([]) + self.vespene_geyser: Units = Units([]) + self.resources: Units = Units([]) + self.destructables: Units = Units([]) + self.watchtowers: Units = Units([]) + self.units: Units = Units([]) for unit in self.observation_raw.units: if unit.is_blip: - blipUnits.append(unit) + self._blipUnits.append(unit) else: + unit_obj = Unit(unit) + self.units.append(unit_obj) alliance = unit.alliance # Alliance.Neutral.value = 3 if alliance == 3: unit_type = unit.unit_type # XELNAGATOWER = 149 if unit_type == 149: - watchtowers.append(unit) - # all destructable rocks except the one below the main base ramps - elif unit.radius > 1.5: - destructables.append(unit) + self.watchtowers.append(unit_obj) # mineral field enums elif unit_type in mineral_ids: - minerals.append(unit) + self.mineral_field.append(unit_obj) + self.resources.append(unit_obj) # geyser enums elif unit_type in geyser_ids: - geysers.append(unit) + self.vespene_geyser.append(unit_obj) + self.resources.append(unit_obj) + # all destructable rocks + else: + self.destructables.append(unit_obj) # Alliance.Self.value = 1 elif alliance == 1: - own.append(unit) + self.own_units.append(unit_obj) # Alliance.Enemy.value = 4 elif alliance == 4: - enemy.append(unit) - - resources = minerals + geysers - visible_units = resources + destructables + enemy + own + watchtowers - - self.own_units: Units = Units.from_proto(own) - self.enemy_units: Units = Units.from_proto(enemy) - self.mineral_field: Units = Units.from_proto(minerals) - self.vespene_geyser: Units = Units.from_proto(geysers) - self.resources: Units = Units.from_proto(resources) - self.destructables: Units = Units.from_proto(destructables) - self.watchtowers: Units = Units.from_proto(watchtowers) - self.units: Units = Units.from_proto(visible_units) + self.enemy_units.append(unit_obj) self.upgrades: Set[UpgradeId] = {UpgradeId(upgrade) for upgrade in self.observation_raw.player.upgrade_ids} # Set of unit tags that died this step self.dead_units: Set[int] = {dead_unit_tag for dead_unit_tag in self.observation_raw.event.dead_units} - - self.blips: Set[Blip] = {Blip(unit) for unit in blipUnits} + # Set of enemy units detected by own sensor tower, as blips have less unit information than normal visible units + self.blips: Set[Blip] = {Blip(unit) for unit in self._blipUnits} # self.visibility[point]: 0=Hidden, 1=Fogged, 2=Visible self.visibility: PixelMap = PixelMap(self.observation_raw.map_state.visibility, mirrored=True) # self.creep[point]: 0=No creep, 1=creep diff --git a/sc2/position.py b/sc2/position.py index a7abbe52a..166bb90f5 100644 --- a/sc2/position.py +++ b/sc2/position.py @@ -18,95 +18,65 @@ def position(self) -> "Pointlike": def distance_to(self, target: Union["Unit", "Point2"]) -> float: """Calculate a single distance from a point or unit to another point or unit""" p = target.position - assert isinstance(p, Pointlike), f"p is not of type Pointlike" - return ((self[0] - p[0]) ** 2 + (self[1] - p[1]) ** 2) ** 0.5 + return math.hypot(self[0] - p[0], self[1] - p[1]) - def old_distance_to(self, p: Union["Unit", "Point2", "Point3"]) -> Union[int, float]: - p = p.position - assert isinstance(p, Pointlike), f"p is not of type Pointlike" - if self == p: - return 0 - return (sum(self.__class__((b - a) ** 2 for a, b in itertools.zip_longest(self, p, fillvalue=0)))) ** 0.5 - - def distance_to_point2(self, p2: "Point2") -> Union[int, float]: - """ Same as the function above, but should be 3-4 times faster because of the dropped asserts and - conversions and because it doesnt use a loop (itertools or zip). """ - return ((self[0] - p2[0]) ** 2 + (self[1] - p2[1]) ** 2) ** 0.5 + def distance_to_point2(self, p: "Point2") -> Union[int, float]: + """ Same as the function above, but should be a bit faster because of the dropped asserts + and conversion. """ + return math.hypot(self[0] - p[0], self[1] - p[1]) def _distance_squared(self, p2: "Point2") -> Union[int, float]: """ Function used to not take the square root as the distances will stay proportionally the same. This is to speed up the sorting process. """ return (self[0] - p2[0]) ** 2 + (self[1] - p2[1]) ** 2 - def is_closer_than(self, d: Union[int, float], p: Union["Unit", "Point2"]) -> bool: - """ Check if another point (or unit) is closer than the given distance. - More efficient than distance_to(p) < d.""" + def is_closer_than(self, distance: Union[int, float], p: Union["Unit", "Point2"]) -> bool: + """ Check if another point (or unit) is closer than the given distance. """ p = p.position - return self._distance_squared(p) < d ** 2 + return self.distance_to_point2(p) < distance - def is_further_than(self, d: Union[int, float], p: Union["Unit", "Point2"]) -> bool: - """ Check if another point (or unit) is further than the given distance. - More efficient than distance_to(p) > d. """ + def is_further_than(self, distance: Union[int, float], p: Union["Unit", "Point2"]) -> bool: + """ Check if another point (or unit) is further than the given distance. """ p = p.position - return self._distance_squared(p) > d ** 2 + return self.distance_to_point2(p) > distance def sort_by_distance(self, ps: Union["Units", List["Point2"]]) -> List["Point2"]: """ This returns the target points sorted as list. You should not pass a set or dict since those are not sortable. If you want to sort your units towards a point, use 'units.sorted_by_distance_to(point)' instead. """ - return sorted(ps, key=lambda p: self._distance_squared(p.position)) + return sorted(ps, key=lambda p: self.distance_to_point2(p.position)) def closest(self, ps: Union["Units", List["Point2"], Set["Point2"]]) -> Union["Unit", "Point2"]: """ This function assumes the 2d distance is meant """ assert ps, f"ps is empty" - closest_distance_squared = math.inf - for p2 in ps: - p2pos = p2 - if not isinstance(p2pos, Point2): - p2pos = p2.position - distance = (self[0] - p2pos[0]) ** 2 + (self[1] - p2pos[1]) ** 2 - if distance <= closest_distance_squared: - closest_distance_squared = distance - closest_element = p2 - return closest_element + return min(ps, key=lambda p: self.distance_to(p)) def distance_to_closest(self, ps: Union["Units", List["Point2"], Set["Point2"]]) -> Union[int, float]: """ This function assumes the 2d distance is meant """ assert ps, f"ps is empty" - closest_distance_squared = math.inf + closest_distance = math.inf for p2 in ps: - if not isinstance(p2, Point2): - p2 = p2.position - distance = (self[0] - p2[0]) ** 2 + (self[1] - p2[1]) ** 2 - if distance <= closest_distance_squared: - closest_distance_squared = distance - return closest_distance_squared ** 0.5 + p2 = p2.position + distance = self.distance_to(p2) + if distance <= closest_distance: + closest_distance = distance + return closest_distance def furthest(self, ps: Union["Units", List["Point2"], Set["Point2"]]) -> Union["Unit", "Pointlike"]: """ This function assumes the 2d distance is meant """ assert ps, f"ps is empty" - furthest_distance_squared = -math.inf - for p2 in ps: - p2pos = p2 - if not isinstance(p2pos, Point2): - p2pos = p2.position - distance = (self[0] - p2pos[0]) ** 2 + (self[1] - p2pos[1]) ** 2 - if furthest_distance_squared <= distance: - furthest_distance_squared = distance - furthest_element = p2 - return furthest_element + return max(ps, key=lambda p: self.distance_to(p)) def distance_to_furthest(self, ps: Union["Units", List["Point2"], Set["Point2"]]) -> Union[int, float]: """ This function assumes the 2d distance is meant """ assert ps, f"ps is empty" - furthest_distance_squared = -math.inf + furthest_distance = -math.inf for p2 in ps: - if not isinstance(p2, Point2): - p2 = p2.position - distance = (self[0] - p2[0]) ** 2 + (self[1] - p2[1]) ** 2 - if furthest_distance_squared <= distance: - furthest_distance_squared = distance - return furthest_distance_squared ** 0.5 + p2 = p2.position + distance = self.distance_to(p2) + if distance >= furthest_distance: + furthest_distance = distance + return furthest_distance def offset(self, p) -> "Pointlike": return self.__class__(a + b for a, b in itertools.zip_longest(self, p[: len(self)], fillvalue=0)) @@ -268,8 +238,8 @@ def __truediv__(self, other: Union[int, float, "Point2"]) -> "Point2": return self.__class__((self.x / other.x, self.y / other.y)) return self.__class__((self.x / other, self.y / other)) - def is_same_as(self, other: "Point2", dist=0.1) -> bool: - return self._distance_squared(other) <= dist ** 2 + def is_same_as(self, other: "Point2", dist=0.001) -> bool: + return self.distance_to_point2(other) <= dist def direction_vector(self, other: "Point2") -> "Point2": """ Converts a vector to a direction that can face vertically, horizontally or diagonal or be zero, e.g. (0, 0), (1, -1), (1, 0) """ diff --git a/sc2/power_source.py b/sc2/power_source.py index 35b237b31..007c70426 100644 --- a/sc2/power_source.py +++ b/sc2/power_source.py @@ -1,5 +1,6 @@ from .position import Point2 + class PowerSource: @classmethod def from_proto(cls, proto): @@ -13,11 +14,12 @@ def __init__(self, position, radius, unit_tag): self.unit_tag = unit_tag def covers(self, position): - return self.position._distance_squared(position) <= self.radius ** 2 + return self.position.distance_to(position) <= self.radius def __repr__(self): return f"PowerSource({self.position}, {self.radius})" + class PsionicMatrix: @classmethod def from_proto(cls, proto): diff --git a/sc2/unit.py b/sc2/unit.py index ec87b95c8..cc7b5edfa 100644 --- a/sc2/unit.py +++ b/sc2/unit.py @@ -405,33 +405,35 @@ def buffs(self) -> Set: def is_carrying_minerals(self) -> bool: """ Checks if a worker or MULE is carrying (gold-)minerals. """ return any( - buff in self.buffs for buff in {BuffId.CARRYMINERALFIELDMINERALS, BuffId.CARRYHIGHYIELDMINERALFIELDMINERALS} + buff in {BuffId.CARRYMINERALFIELDMINERALS, BuffId.CARRYHIGHYIELDMINERALFIELDMINERALS} for buff in self.buffs ) @property_immutable_cache def is_carrying_vespene(self) -> bool: """ Checks if a worker is carrying vespene gas. """ return any( - buff in self.buffs - for buff in { + buff + in { BuffId.CARRYHARVESTABLEVESPENEGEYSERGAS, BuffId.CARRYHARVESTABLEVESPENEGEYSERGASPROTOSS, BuffId.CARRYHARVESTABLEVESPENEGEYSERGASZERG, } + for buff in self.buffs ) @property_immutable_cache def is_carrying_resource(self) -> bool: """ Checks if a worker is carrying a resource. """ return any( - buff in self.buffs - for buff in { + buff + in { BuffId.CARRYMINERALFIELDMINERALS, BuffId.CARRYHIGHYIELDMINERALFIELDMINERALS, BuffId.CARRYHARVESTABLEVESPENEGEYSERGAS, BuffId.CARRYHARVESTABLEVESPENEGEYSERGASPROTOSS, BuffId.CARRYHARVESTABLEVESPENEGEYSERGASZERG, } + for buff in self.buffs ) @property_immutable_cache @@ -764,10 +766,7 @@ def target_in_range(self, target: "Unit", bonus_distance: Union[int, float] = 0) unit_attack_range = self.air_range else: return False - return ( - self.position._distance_squared(target.position) - <= (self.radius + target.radius + unit_attack_range + bonus_distance) ** 2 - ) + return self.distance_to(target) <= self.radius + target.radius + unit_attack_range + bonus_distance def has_buff(self, buff) -> bool: """ Checks if unit has buff 'buff'. """ diff --git a/sc2/units.py b/sc2/units.py index cbac896de..798f1437b 100644 --- a/sc2/units.py +++ b/sc2/units.py @@ -136,21 +136,18 @@ def in_attack_range_of(self, unit: Unit, bonus_distance: Union[int, float] = 0) def closest_distance_to(self, position: Union[Unit, Point2, Point3]) -> Union[int, float]: """ Returns the distance between the closest unit from this group to the target unit """ assert self, "Units object is empty" - if isinstance(position, Unit): - position = position.position + position = position.position return position.distance_to_closest(u.position for u in self) def furthest_distance_to(self, position: Union[Unit, Point2, Point3]) -> Union[int, float]: """ Returns the distance between the furthest unit from this group to the target unit """ assert self, "Units object is empty" - if isinstance(position, Unit): - position = position.position + position = position.position return position.distance_to_furthest(u.position for u in self) def closest_to(self, position: Union[Unit, Point2, Point3]) -> Unit: assert self, "Units object is empty" - if isinstance(position, Unit): - position = position.position + position = position.position return position.closest(self) def furthest_to(self, position: Union[Unit, Point2, Point3]) -> Unit: @@ -160,16 +157,12 @@ def furthest_to(self, position: Union[Unit, Point2, Point3]) -> Unit: return position.furthest(self) def closer_than(self, distance: Union[int, float], position: Union[Unit, Point2, Point3]) -> "Units": - if isinstance(position, Unit): - position = position.position - distance_squared = distance ** 2 - return self.filter(lambda unit: unit.position._distance_squared(position.to2) < distance_squared) + position = position.position + return self.filter(lambda unit: unit.distance_to(position.to2) < distance) def further_than(self, distance: Union[int, float], position: Union[Unit, Point2, Point3]) -> "Units": - if isinstance(position, Unit): - position = position.position - distance_squared = distance ** 2 - return self.filter(lambda unit: unit.position._distance_squared(position.to2) > distance_squared) + position = position.position + return self.filter(lambda unit: unit.distance_to(position.to2) > distance) def subgroup(self, units): return Units(units) @@ -183,7 +176,7 @@ def sorted(self, keyfn: callable, reverse: bool = False) -> "Units": def sorted_by_distance_to(self, position: Union[Unit, Point2], reverse: bool = False) -> "Units": """ This function should be a bit faster than using units.sorted(keyfn=lambda u: u.distance_to(position)) """ position = position.position - return self.sorted(keyfn=lambda unit: unit.position._distance_squared(position), reverse=reverse) + return self.sorted(keyfn=lambda unit: unit.distance_to(position), reverse=reverse) def tags_in(self, other: Union[Set[int], List[int], Dict[int, Any]]) -> "Units": """ Filters all units that have their tags in the 'other' set/list/dict """ diff --git a/test/test_pickled_data.py b/test/test_pickled_data.py index 615706ff0..e0e61ce12 100644 --- a/test/test_pickled_data.py +++ b/test/test_pickled_data.py @@ -114,7 +114,9 @@ def test_bot_ai(self, bot: BotAI): # TODO: Cache all expansion positions for a map and check if it is the same assert len(bot.expansion_locations) >= 12 # On N player maps, it is expected that there are N*X bases because of symmetry, at least for 1vs1 maps - assert len(bot.expansion_locations) % (len(bot.enemy_start_locations) + 1) == 0, f"{set(bot.expansion_locations.keys())}" + assert ( + len(bot.expansion_locations) % (len(bot.enemy_start_locations) + 1) == 0 + ), f"{set(bot.expansion_locations.keys())}" # Test if bot start location is in expansion locations assert bot.townhalls.random.position in set( bot.expansion_locations.keys() @@ -528,27 +530,26 @@ def test_units(self, bot: BotAI): assert townhalls.prefer_idle @given( - st.integers(min_value=-1e10, max_value=1e10), - st.integers(min_value=-1e10, max_value=1e10), - st.integers(min_value=-1e10, max_value=1e10), - st.integers(min_value=-1e10, max_value=1e10), - st.integers(min_value=-1e10, max_value=1e10), - st.integers(min_value=-1e10, max_value=1e10), + st.integers(min_value=-1e5, max_value=1e5), + st.integers(min_value=-1e5, max_value=1e5), + st.integers(min_value=-1e5, max_value=1e5), + st.integers(min_value=-1e5, max_value=1e5), + st.integers(min_value=-1e5, max_value=1e5), + st.integers(min_value=-1e5, max_value=1e5), ) @settings(max_examples=500) def test_position_pointlike(self, bot: BotAI, x1, y1, x2, y2, x3, y3): pos1 = Point2((x1, y1)) pos2 = Point2((x2, y2)) pos3 = Point2((x3, y3)) + epsilon = 1e-3 assert pos1.position == pos1 dist = ((x2 - x1) ** 2 + (y2 - y1) ** 2) ** 0.5 - assert pos1.distance_to(pos2) == dist - assert pos1.old_distance_to(pos2) == dist - assert pos1.distance_to_point2(pos2) == dist - assert pos1._distance_squared(pos2) ** 0.5 == dist + assert abs(pos1.distance_to(pos2) - dist) <= epsilon + assert abs(pos1.distance_to_point2(pos2) - dist) <= epsilon + assert abs(pos1._distance_squared(pos2) ** 0.5 - dist) <= epsilon - epsilon = 1e-1 - if epsilon < dist < 1e10: + if epsilon < dist < 1e5: assert pos1.is_closer_than(dist + epsilon, pos2) assert pos1.is_further_than(dist - epsilon, pos2) @@ -599,10 +600,10 @@ def test_position_pointlike(self, bot: BotAI, x1, y1, x2, y2, x3, y3): assert isinstance(hash(pos3), int) @given( - st.integers(min_value=-1e10, max_value=1e10), - st.integers(min_value=-1e10, max_value=1e10), - st.integers(min_value=-1e10, max_value=1e10), - st.integers(min_value=-1e10, max_value=1e10), + st.integers(min_value=-1e5, max_value=1e5), + st.integers(min_value=-1e5, max_value=1e5), + st.integers(min_value=-1e5, max_value=1e5), + st.integers(min_value=-1e5, max_value=1e5), ) @settings(max_examples=500) def test_position_point2(self, bot: BotAI, x1, y1, x2, y2): @@ -641,9 +642,9 @@ def test_position_point2(self, bot: BotAI, x1, y1, x2, y2): assert pos1.unit_axes_towards(pos2) == pos1.direction_vector(pos2) @given( - st.integers(min_value=-1e10, max_value=1e10), - st.integers(min_value=-1e10, max_value=1e10), - st.integers(min_value=-1e10, max_value=1e10), + st.integers(min_value=-1e5, max_value=1e5), + st.integers(min_value=-1e5, max_value=1e5), + st.integers(min_value=-1e5, max_value=1e5), ) @settings(max_examples=10) def test_position_point2(self, bot: BotAI, x1, y1, z1): @@ -651,7 +652,7 @@ def test_position_point2(self, bot: BotAI, x1, y1, z1): assert pos1.z == z1 assert pos1.to3 == pos1 - @given(st.integers(min_value=-1e10, max_value=1e10), st.integers(min_value=-1e10, max_value=1e10)) + @given(st.integers(min_value=-1e5, max_value=1e5), st.integers(min_value=-1e5, max_value=1e5)) @settings(max_examples=20) def test_position_size(self, bot: BotAI, w, h): size = Size((w, h)) @@ -659,10 +660,10 @@ def test_position_size(self, bot: BotAI, w, h): assert size.height == h @given( - st.integers(min_value=-1e10, max_value=1e10), - st.integers(min_value=-1e10, max_value=1e10), - st.integers(min_value=-1e10, max_value=1e10), - st.integers(min_value=-1e10, max_value=1e10), + st.integers(min_value=-1e5, max_value=1e5), + st.integers(min_value=-1e5, max_value=1e5), + st.integers(min_value=-1e5, max_value=1e5), + st.integers(min_value=-1e5, max_value=1e5), ) @settings(max_examples=20) def test_position_rect(self, bot: BotAI, x, y, w, h):