This repository was archived by the owner on Sep 10, 2025. It is now read-only.
File tree Expand file tree Collapse file tree 3 files changed +17
-7
lines changed Expand file tree Collapse file tree 3 files changed +17
-7
lines changed Original file line number Diff line number Diff line change 8181 REQUIREMENTS_TO_INSTALL=(
8282 torch==" 2.7.0.${PYTORCH_NIGHTLY_VERSION} "
8383 torchvision==" 0.22.0.${VISION_NIGHTLY_VERSION} "
84- torchtune==" 0.6.0"
84+ # torchtune=="0.6.0" # no 0.6.0 on xpu nightly
8585 )
8686else
8787 REQUIREMENTS_TO_INSTALL=(
115115 " ${REQUIREMENTS_TO_INSTALL[@]} "
116116)
117117
118+ # Temporatory instal torchtune nightly from cpu nightly link since no torchtune nightly for xpu now
119+ # TODO: Change to install torchtune from xpu nightly link, once torchtune xpu nightly is ready
120+ if [[ -x " $( command -v xpu-smi) " ]];
121+ then
122+ (
123+ set -x
124+ $PIP_EXECUTABLE install --extra-index-url " https://download.pytorch.org/whl/nightly/cpu" \
125+ torchtune==" 0.6.0.${TUNE_NIGHTLY_VERSION} "
126+ )
127+ fi
128+
118129# For torchao need to install from github since nightly build doesn't have macos build.
119130# TODO: Remove this and install nightly build, once it supports macos
120131(
Original file line number Diff line number Diff line change 2929from torchchat .utils .build_utils import (
3030 device_sync ,
3131 is_cpu_device ,
32- is_cuda_or_cpu_device ,
32+ is_cuda_or_cpu_or_xpu_device ,
3333 name_to_dtype ,
3434)
3535from torchchat .utils .measure_time import measure_time
@@ -539,7 +539,7 @@ def _initialize_model(
539539 _set_gguf_kwargs (builder_args , is_et = is_pte , context = "generate" )
540540
541541 if builder_args .dso_path :
542- if not is_cuda_or_cpu_device (builder_args .device ):
542+ if not is_cuda_or_cpu_or_xpu_device (builder_args .device ):
543543 print (
544544 f"Cannot load specified DSO to { builder_args .device } . Attempting to load model to CPU instead"
545545 )
@@ -573,7 +573,7 @@ def do_nothing(max_batch_size, max_seq_length):
573573 raise RuntimeError (f"Failed to load AOTI compiled { builder_args .dso_path } " )
574574
575575 elif builder_args .aoti_package_path :
576- if not is_cuda_or_cpu_device (builder_args .device ):
576+ if not is_cuda_or_cpu_or_xpu_device (builder_args .device ):
577577 print (
578578 f"Cannot load specified PT2 to { builder_args .device } . Attempting to load model to CPU instead"
579579 )
Original file line number Diff line number Diff line change @@ -303,6 +303,5 @@ def get_device(device) -> str:
303303def is_cpu_device (device ) -> bool :
304304 return device == "" or str (device ) == "cpu"
305305
306-
307- def is_cuda_or_cpu_device (device ) -> bool :
308- return is_cpu_device (device ) or ("cuda" in str (device ))
306+ def is_cuda_or_cpu_or_xpu_device (device ) -> bool :
307+ return is_cpu_device (device ) or ("cuda" in str (device )) or ("xpu" in str (device ))
You can’t perform that action at this time.
0 commit comments