English 简体中文 繁體中文 한국 사람 日本語 Deutsch русский بالعربية TÜRKÇE português คนไทย french
查看: 6|回复: 0

Python · Jax | 在 python 3.8 上安装 jax,运行 offline RL 的 IQL

[复制链接]
查看: 6|回复: 0

Python · Jax | 在 python 3.8 上安装 jax,运行 offline RL 的 IQL

[复制链接]
查看: 6|回复: 0

392

主题

0

回帖

1184

积分

管理员

积分
1184
溯源设备

392

主题

0

回帖

1184

积分

管理员

积分
1184
2025-2-5 18:04:03 | 显示全部楼层 |阅读模式
致谢师兄的 jax 环境,完全按照师兄的 conda_env.yml 配置的
(如何导出其他环境的 conda_env.yml:Conda | 如何(在新服务器上)复制一份旧服务器的 conda 环境,Linux 服务器

目录


<hr>首先,新建一个 conda 环境:
conda create -n jax_env python==3.8conda activate jax_env(如何配置 conda:Conda | 如何在 Linux 服务器安装 conda
01 安装各种库

直接 pip 安装:
pip install numpy==1.21.6 torch==1.13.1 wandb==0.15.10 \transformers==4.30.2 typing-extensions==4.7.1 optax==0.1.4 \jax==0.3.24 flax==0.6.0 cloudpickle==2.2.1 distrax==0.1.3 \glfw==2.6.2 gym==0.15.7 ml-collections==0.1.1 tensorboardx==2.1 \protobuf==3.20.1 ujson==5.7.0 pynvml02 安装 jax

jax 把自己的库放在了网站上:
要安装 0.3.24 的 jax,可以运行:
pip install "jax[cuda11_cudnn82]==0.3.24" \-f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html需要注意:

  • jax jaxlib optax flax 等库,它们的版本有对应关系,可按照这篇博客的参考版本安装;
  • 需要 pip install cloudpickle==2.2.1,好像很容易安装成 1.2.2 版本,最后要检查一下版本;protobuf==3.20.1 这个版本也是。
  • 编译的时候,因为 ptxas 版本太低报错,可以运行 which ptxas,查看现在在用哪个 ptxas 版本。如果发现在用老 cuda 版本,则去改 path,修改 ~/.bashrc,添加
export PATH="/usr/local/cuda-{版本号}/bin:$PATH"export LD_LIBRARY_PATH="/usr/local/cuda-{版本号}/lib64:$LD_LIBRARY_PATH"# cuda 版本号可以看 /usr/local 目录里有哪些版本,我用的是 11.703 安装 dm_control metaworld d4rl

需要先安装 MuJoCo,可参见这篇:Python · MuJoCo | MuJoCo 与 mujoco_py 的版本对应,以及安装 Cython<3
先把 dm_control metaworld d4rl 这三个库拿下来:
git clone git@github.com:Farama-Foundation/Metaworld.gitgit clone git@github.com:Farama-Foundation/D4RL.gitgit clone git@github.com:denisyarats/dmc2gym.git然后分别进入它们的路径,执行 pip install -e . 即可。
04 测试

我跑的是 https://github.com/csmile-1006/PreferenceTransformer 这个库,它里面也有 IQL 的 jax 实现,所以这个环境应该是能跑 IQL jax 的)
05 各种库的参考版本

以下是一个参考环境的版本:
name: jax_envchannels:  - defaultsdependencies:  - _libgcc_mutex=0.1=main  - ca-certificates=2023.08.22=h06a4308_0  - certifi=2022.12.7=py37h06a4308_0  - ld_impl_linux-64=2.38=h1181459_1  - libffi=3.3=he6710b0_2  - libgcc-ng=9.1.0=hdf63c60_0  - libstdcxx-ng=9.1.0=hdf63c60_0  - ncurses=6.3=h7f8727e_2  - openssl=1.1.1w=h7f8727e_0  - pip=22.3.1=py37h06a4308_0  - python=3.7.13=h12debd9_0  - readline=8.1.2=h7f8727e_1  - setuptools=65.6.3=py37h06a4308_0  - sqlite=3.38.5=hc218d9a_0  - tk=8.6.12=h1ccaba5_0  - wheel=0.38.4=py37h06a4308_0  - xz=5.2.5=h7f8727e_1  - zlib=1.2.12=h7f8727e_2  - pip:    - absl-py==1.4.0    - appdirs==1.4.4    - beautifulsoup4==4.12.2    - cffi==1.15.1    - charset-normalizer==3.2.0    - chex==0.1.5    - click==8.1.7    - cloudpickle==2.2.1    - colorama==0.4.6    - commonmark==0.9.1    - contextlib2==21.6.0    - cycler==0.11.0    - cython==3.0.2    - decorator==5.1.1    - distrax==0.1.3    - dm-control==1.0.13    - dm-env==1.6    - dm-tree==0.1.8    - docker-pycreds==0.4.0    - etils==0.9.0    - fasteners==0.18    - filelock==3.12.2    - flax==0.6.0    - fonttools==4.38.0    - fsspec==2023.1.0    - future==0.18.3    - gast==0.5.4    - gdown==4.7.1    - gitdb==4.0.10    - gitpython==3.1.36    - glfw==2.6.2    - gym==0.15.7    - gym-notices==0.0.8    - h5py==3.8.0    - huggingface-hub==0.16.4    - idna==3.4    - imageio==2.31.2    - imageio-ffmpeg==0.4.9    - importlib-metadata==6.7.0    - importlib-resources==5.12.0    - jax==0.3.24    - jaxlib==0.3.24+cuda11.cudnn82    - joblib==1.3.2    - kiwisolver==1.4.5    - labmaze==1.0.6    - lxml==4.9.3    - matplotlib==3.5.3    - ml-collections==0.1.1    - msgpack==1.0.5    - mujoco==2.3.6    - mujoco-py==2.0.2.13    - numpy==1.21.6    - nvidia-cublas-cu11==11.10.3.66    - nvidia-cuda-nvrtc-cu11==11.7.99    - nvidia-cuda-runtime-cu11==11.7.99    - nvidia-cudnn-cu11==8.5.0.96    - opt-einsum==3.3.0    - optax==0.1.4    - packaging==23.1    - pathtools==0.1.2    - pillow==9.5.0    - protobuf==3.20.1    - psutil==5.9.5    - pybullet==3.2.5    - pycparser==2.21    - pyglet==1.5.0    - pygments==2.16.1    - pyopengl==3.1.7    - pyparsing==3.1.1    - pysocks==1.7.1    - python-dateutil==2.8.2    - pyyaml==6.0.1    - regex==2023.8.8    - requests==2.31.0    - rich==11.2.0    - safetensors==0.3.3    - scikit-learn==1.0.2    - scipy==1.7.3    - sentry-sdk==1.31.0    - setproctitle==1.3.2    - six==1.16.0    - smmap==5.0.1    - soupsieve==2.4.1    - tensorboardx==2.1    - tensorflow-probability==0.19.0    - termcolor==2.3.0    - threadpoolctl==3.1.0    - tokenizers==0.13.3    - toolz==0.12.0    - torch==1.13.1    - tqdm==4.66.1    - transformers==4.30.2    - typing-extensions==4.7.1    - ujson==5.7.0    - urllib3==2.0.4    - wandb==0.15.10    - zipp==3.15.0prefix: /home/user_name/miniconda3/envs/jax
您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

QQ|智能设备 | 粤ICP备2024353841号-1

GMT+8, 2025-3-10 15:24 , Processed in 3.808887 second(s), 27 queries .

Powered by 智能设备

©2025

|网站地图