1414 dict (torch = "1.8.0" , torchvision = "0.9.0" , torchtext = "0.9.0" ),
1515 dict (torch = "1.7.1" , torchvision = "0.8.2" , torchtext = "0.8.1" ),
1616 dict (torch = "1.7.0" , torchvision = "0.8.1" , torchtext = "0.8.0" ),
17- dict (torch = "1.6.0" , torchvision = "0.7.0" , torchtext = "0.7" ),
1817]
1918
2019
@@ -33,28 +32,59 @@ def find_latest(ver: str) -> Dict[str, str]:
3332 raise ValueError (f"Missing { ver } in { VERSIONS } " )
3433
3534
36- def main (path_req : str , torch_version : Optional [str ] = None ) -> None :
35+ def main (req : str , torch_version : Optional [str ] = None ) -> str :
3736 if not torch_version :
3837 import torch
3938
4039 torch_version = torch .__version__
4140 assert torch_version , f"invalid torch: { torch_version } "
4241
43- with open (path_req ) as fp :
44- req = fp .read ()
45- # remove comments
46- req = re .sub (rf"\s*#.*{ os .linesep } " , os .linesep , req )
42+ # remove comments and strip whitespace
43+ req = re .sub (rf"\s*#.*{ os .linesep } " , os .linesep , req ).strip ()
4744
4845 latest = find_latest (torch_version )
4946 for lib , version in latest .items ():
50- replace = f"{ lib } =={ version } " if version else lib
51- replace += os .linesep
52- req = re .sub (rf"{ lib } [>=]*[\d\.]*{ os .linesep } " , replace , req )
47+ replace = f"{ lib } =={ version } " if version else ""
48+ req = re .sub (rf"\b{ lib } (?!\w).*" , replace , req )
5349
54- print (req ) # on purpose - to debug
55- with open (path_req , "w" ) as fp :
56- fp .write (req )
50+ return req
51+
52+
53+ def test ():
54+ requirements = """
55+ torch>=1.2.*
56+ torch==1.2.3
57+ torch==1.4
58+ torch
59+ future>=0.17.1
60+ pytorch==1.5.6+123dev0
61+ torchvision
62+ torchmetrics>=0.4.1
63+ """
64+ expected = """
65+ torch==1.9.1
66+ torch==1.9.1
67+ torch==1.9.1
68+ torch==1.9.1
69+ future>=0.17.1
70+ pytorch==1.5.6+123dev0
71+ torchvision==0.10.1
72+ torchmetrics>=0.4.1
73+ """ .strip ()
74+ actual = main (requirements , "1.9" )
75+ assert actual == expected , (actual , expected )
5776
5877
5978if __name__ == "__main__" :
60- main (* sys .argv [1 :])
79+ test () # sanity check
80+
81+ if len (sys .argv ) == 3 :
82+ requirements_path , torch_version = sys .argv [1 :]
83+ else :
84+ requirements_path , torch_version = sys .argv [1 ], None
85+
86+ with open (requirements_path , "r+" ) as fp :
87+ requirements = fp .read ()
88+ requirements = main (requirements , torch_version )
89+ print (requirements ) # on purpose - to debug
90+ fp .write (requirements )
0 commit comments