# Build C++ targets created from here on as C++17 and do not allow a silent
# fallback to an older standard.
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_STANDARD 17)

# Opt-in switch (default OFF): fetch libtorch with FetchContent when no
# LIBTORCH_ROOT has been provided.
option(MACE_DOWNLOAD_LIBTORCH "Download libtorch with FetchContent" OFF)

# Locate libtorch. Precedence:
#   1. LIBTORCH_ROOT, when it is set to a non-empty value.
#   2. A FetchContent download, when MACE_DOWNLOAD_LIBTORCH is ON.
#   3. Otherwise configuration stops with an error.
#
# An empty LIBTORCH_ROOT (e.g. a cache entry declared with "" as its default)
# is treated as "not set" so that it cannot mask the download branch or produce
# the bogus prefix path "/share/cmake/Torch".
if(DEFINED LIBTORCH_ROOT AND NOT "${LIBTORCH_ROOT}" STREQUAL "")
  message(STATUS "Using user-provided libtorch at: ${LIBTORCH_ROOT}")
elseif(MACE_DOWNLOAD_LIBTORCH)
  include(FetchContent)

  # Defaults only: a LIBTORCH_VERSION / LIBTORCH_URL that is already defined
  # (for example via -DLIBTORCH_URL=<url>) takes precedence.
  if(NOT DEFINED LIBTORCH_VERSION)
    set(LIBTORCH_VERSION "2.3.0")
  endif()
  # Example URL for the CPU-only build. Provide a CUDA build URL instead if
  # GPU support is required.
  if(NOT DEFINED LIBTORCH_URL)
    set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cpu/libtorch-cxx11-abi-shared-with-deps-${LIBTORCH_VERSION}.zip")
  endif()

  # NOTE(review): no URL_HASH is given, so the downloaded archive is not
  # integrity-checked.
  FetchContent_Declare(
    libtorch
    URL      ${LIBTORCH_URL}
    DOWNLOAD_EXTRACT_TIMESTAMP TRUE
  )
  FetchContent_MakeAvailable(libtorch)

  # The extracted archive becomes the libtorch root for the rest of this file.
  set(LIBTORCH_ROOT "${libtorch_SOURCE_DIR}")
  message(STATUS "Downloaded libtorch to: ${LIBTORCH_ROOT}")
else()
  message(FATAL_ERROR
    "LIBTORCH_ROOT is not set and MACE_DOWNLOAD_LIBTORCH is OFF.\n"
    "Set -DLIBTORCH_ROOT=/path/to/libtorch or enable -DMACE_DOWNLOAD_LIBTORCH=ON.")
endif()

# Let find_package(Torch) search share/cmake/Torch below the libtorch root.
list(APPEND CMAKE_PREFIX_PATH "${LIBTORCH_ROOT}/share/cmake/Torch")

# Translate CMake CUDA architecture entries into the dotted "major.minor"
# spelling and store the de-duplicated result in <out_var>.
#
# Arguments:
#   out_var - name of the variable set in the caller's scope
#   ARGN    - architecture entries such as 86, 75-real, 80-virtual, 90a, 100
#
# Conversion rules:
#   * "-real" / "-virtual" qualifiers are dropped.
#   * A two- or three-digit number, optionally followed by one lowercase
#     suffix letter, is split so the last digit is the minor version
#     (86 -> 8.6, 100 -> 10.0, 90a -> 9.0a).
#   * Anything else is passed through unchanged.
#
# Example:
#   mace_cuda_archs_to_torch_list(_list 86-real 90a 100)  # _list = 8.6;9.0a;10.0
function(mace_cuda_archs_to_torch_list out_var)
  set(_converted "")
  foreach(_entry IN LISTS ARGN)
    string(REPLACE "-real" "" _entry "${_entry}")
    string(REPLACE "-virtual" "" _entry "${_entry}")
    if(_entry MATCHES "^([0-9][0-9]?)([0-9])([a-z]?)$")
      list(APPEND _converted "${CMAKE_MATCH_1}.${CMAKE_MATCH_2}${CMAKE_MATCH_3}")
    else()
      list(APPEND _converted "${_entry}")
    endif()
  endforeach()
  list(REMOVE_DUPLICATES _converted)
  set(${out_var} "${_converted}" PARENT_SCOPE)
endfunction()

# Derive TORCH_CUDA_ARCH_LIST from CMAKE_CUDA_ARCHITECTURES unless the caller
# has already defined TORCH_CUDA_ARCH_LIST.
if(CMAKE_CUDA_ARCHITECTURES AND NOT DEFINED TORCH_CUDA_ARCH_LIST)
  if("${CMAKE_CUDA_ARCHITECTURES}" MATCHES "^(native|all|all-major)$")
    # These CMake keywords are not architecture numbers and cannot be rewritten
    # as "major.minor"; forwarding them verbatim is presumably rejected by
    # Torch's architecture parsing (TODO confirm), so Torch's own default
    # selection is left in effect instead.
    message(STATUS "CMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES} is not translated; leaving TORCH_CUDA_ARCH_LIST unset")
  else()
    mace_cuda_archs_to_torch_list(_torch_cuda_arch_list ${CMAKE_CUDA_ARCHITECTURES})
    set(TORCH_CUDA_ARCH_LIST "${_torch_cuda_arch_list}")
    message(STATUS "Using TORCH_CUDA_ARCH_LIST from CMAKE_CUDA_ARCHITECTURES: ${TORCH_CUDA_ARCH_LIST}")
  endif()
endif()

# Torch is mandatory: configuration stops here if the package is not found.
find_package(Torch REQUIRED)

# Imported targets created by find_package(Torch) are scoped to this directory
# by default. Mark the ones that are used transitively as global so that the
# parent PIMD targets can link against mace_fortran without running into
# target visibility errors. Names that Torch did not define are skipped.
set(_mace_torch_promoted_targets
    c10
    c10_cuda
    torch
    torch_cpu
    torch_cpu_library
    torch_cuda
    torch_cuda_library)
foreach(_tgt IN LISTS _mace_torch_promoted_targets)
  if(NOT TARGET ${_tgt})
    continue()
  endif()
  set_target_properties(${_tgt} PROPERTIES IMPORTED_GLOBAL TRUE)
endforeach()
unset(_mace_torch_promoted_targets)

# Compile with the libstdc++ ABI that the located libtorch asks for.
# TORCH_CXX_FLAGS is expected to be set by find_package(Torch) and to carry
# -D_GLIBCXX_USE_CXX11_ABI=<0|1> on Linux (see the PyTorch C++ installation
# docs). When it is absent or does not mention the macro, keep the previous
# default of 1. Hard-coding 1 would contradict a libtorch built with the old
# ABI.
set(_mace_glibcxx_cxx11_abi "1")
if("${TORCH_CXX_FLAGS}" MATCHES "-D_GLIBCXX_USE_CXX11_ABI=([01])")
  set(_mace_glibcxx_cxx11_abi "${CMAKE_MATCH_1}")
endif()
add_definitions(-D_GLIBCXX_USE_CXX11_ABI=${_mace_glibcxx_cxx11_abi})

# Output directory for the libtorch backend, created at configure time.
set(MACE_LIBTORCH_BACKEND_OUTPUT_DIR "${CMAKE_BINARY_DIR}/lib/mace")
file(MAKE_DIRECTORY "${MACE_LIBTORCH_BACKEND_OUTPUT_DIR}")

# Descend into src; the settings above are in effect for that directory.
add_subdirectory(src)
