#!/usr/bin/env python3
"""
download.py - Download/repair model files, update model.yaml metadata,
and commit/push changes with proper Git LFS handling (no bash script needed).

Usage:
    ./tools/download.py models/llama-2-7b-chat/model.yaml

What this does:
- (Re)runs snapshot_download with resume support, so partially fetched directories
  get completed instead of being skipped.
- Avoids adding Hugging Face housekeeping like ".cache/**" to your YAML.
- Updates YAML after each variant with fresh file list + total size.
- Tracks LFS via sensible patterns (plus a size threshold fallback) using
  repo-relative paths so it actually applies.
- Runs a built-in cleanup step (commit, push, optional LFS push, and prune),
  replacing the old cleanup.sh.
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import os
|
||
import sys
|
||
import yaml
|
||
import subprocess
|
||
from pathlib import Path
|
||
from typing import Iterable, List, Optional
|
||
from huggingface_hub import snapshot_download
|
||
|
||
# ----------------------------
# Configuration
# ----------------------------

# Glob patterns for artifact types that should always live in Git LFS.
LFS_PATTERNS: list[str] = [
    "*.safetensors",
    "*.bin",
    "*.pt",
    "*.gguf",
    "*.onnx",
    "*.ckpt",
    "*.tensors",
    "*.npz",
    "*.tar",
    "*.tar.gz",
    "*.zip",
]

# Per-file fallback: anything larger than this gets LFS-tracked individually
# even when no pattern above matched it.
SIZE_THRESHOLD_BYTES = 1_000_000  # 1 MB

# By default we skip pushing all LFS objects (same as prior bash script).
# Set env GIT_LFS_PUSH_ALL=1 to force a full "git lfs push origin --all".
LFS_PUSH_ALL = os.environ.get("GIT_LFS_PUSH_ALL", "0") == "1"
|
||
|
||
|
||
# ----------------------------
# Small subprocess helpers
# ----------------------------


def run(cmd: list[str], check: bool = True, cwd: Optional[Path] = None) -> None:
    """Execute *cmd*; raise on non-zero exit when *check* is True, optionally in *cwd*."""
    workdir = str(cwd) if cwd else None
    subprocess.run(cmd, check=check, cwd=workdir)
|
||
|
||
|
||
def run_capture(cmd: list[str], cwd: Optional[Path] = None) -> str:
    """Execute *cmd* and return its stripped stdout; stderr is discarded."""
    raw = subprocess.check_output(
        cmd,
        cwd=str(cwd) if cwd else None,
        stderr=subprocess.DEVNULL,
    )
    return raw.decode().strip()
|
||
|
||
|
||
# ----------------------------
# Git / LFS utilities
# ----------------------------


def ensure_repo_root() -> Path:
    """
    Ensure we're in a git repo; install LFS filters locally; return repo root.

    Falls back to the current working directory (with a warning) when git is
    unavailable or we are not inside a work tree.
    """
    try:
        subprocess.run(
            ["git", "rev-parse", "--is-inside-work-tree"],
            check=True,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
        )
        # Idempotent: activate the LFS smudge/clean filters for this repo only.
        subprocess.run(
            ["git", "lfs", "install", "--local"],
            check=False,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
        )
        top = subprocess.check_output(
            ["git", "rev-parse", "--show-toplevel"],
            stderr=subprocess.DEVNULL,
        )
        return Path(top.decode().strip())
    except Exception:
        print("⚠️ Not inside a Git repository? Git/LFS steps may fail.", file=sys.stderr)
        return Path.cwd()
|
||
|
||
|
||
def repo_relative_path(repo_root: Path, p: Path) -> Path:
    """
    Return *p* relative to *repo_root*.

    Works even if *p* is not a strict subpath of the root: in that case we
    fall back to os.path.relpath (handles symlinks / different mounts).
    """
    resolved_root = repo_root.resolve()
    resolved_target = p.resolve()
    try:
        return resolved_target.relative_to(resolved_root)
    except Exception:
        return Path(os.path.relpath(resolved_target, resolved_root))
|
||
|
||
|
||
def lfs_track_patterns(patterns: Iterable[str]) -> None:
    """
    Track a set of glob patterns in Git LFS (idempotent).

    Failures are ignored: the per-file size fallback applied later still
    catches anything these patterns miss.
    """
    for pattern in patterns:
        try:
            subprocess.run(["git", "lfs", "track", pattern], check=False)
        except Exception:
            pass
|
||
|
||
|
||
def lfs_track_file(repo_root: Path, path_in_repo: Path) -> None:
    """
    Track an individual file in Git LFS using a repo-relative path.
    """
    # POSIX-style separators keep .gitattributes entries consistent across OSes.
    rel = path_in_repo.as_posix()
    try:
        subprocess.run(["git", "lfs", "track", rel], check=False, cwd=str(repo_root))
    except Exception:
        pass
|
||
|
||
|
||
def git_status_has_changes(repo_root: Path) -> bool:
    """Return True when `git status --porcelain` reports anything pending in *repo_root*."""
    try:
        out = subprocess.check_output(
            ["git", "status", "--porcelain=v1"],
            cwd=str(repo_root),
            stderr=subprocess.DEVNULL,
        )
        return bool(out.decode().strip())
    except Exception:
        # Not a repo / git missing: treat as "nothing to commit".
        return False
|
||
|
||
|
||
def git_stage_and_commit_push(
    repo_root: Path,
    scope_paths: list[Path],
    commit_message: str,
    lfs_push_all: bool = False,
) -> None:
    """
    Stages .gitattributes + renormalizes only the provided scope paths (e.g., 'models/…'),
    then commits, pushes, optionally pushes all LFS objects, and finally prunes LFS.

    Args:
        repo_root: Root of the git repository to operate in.
        scope_paths: Directories to renormalize; empty falls back to the whole repo.
        commit_message: Message used for the commit.
        lfs_push_all: When True, also run "git lfs push origin --all".
    """
    # Stage .gitattributes explicitly (ignore failures)
    try:
        run(["git", "add", ".gitattributes"], check=False, cwd=repo_root)
    except Exception:
        pass

    # Renormalize only the relevant directories to avoid sweeping the whole repo.
    # If scope_paths is empty, fall back to full repo (conservative).
    if scope_paths:
        for sp in scope_paths:
            rel = repo_relative_path(repo_root, sp)
            run(["git", "add", "--renormalize", str(rel)], check=False, cwd=repo_root)
    else:
        run(["git", "add", "--renormalize", "."], check=False, cwd=repo_root)

    # If nothing is staged, skip.
    # BUG FIX: the previous code called subprocess.run without check=True, which
    # never raises on a non-zero exit status, so staged_is_empty was *always*
    # True and the commit/push below was unconditionally skipped. Inspect the
    # return code instead: 'git diff --cached --quiet' exits 0 when there are
    # no staged changes.
    try:
        probe = subprocess.run(["git", "diff", "--cached", "--quiet"], cwd=str(repo_root))
        staged_is_empty = probe.returncode == 0
    except Exception:
        staged_is_empty = False

    if staged_is_empty:
        print("No staged changes after normalization. Skipping commit and push.")
        return

    print("Committing and pushing changes...")
    run(["git", "commit", "-m", commit_message], cwd=repo_root)

    # Push main refs
    run(["git", "push"], cwd=repo_root)

    # Optionally ensure all LFS objects are uploaded
    if lfs_push_all:
        try:
            run(["git", "lfs", "push", "origin", "--all"], check=True, cwd=repo_root)
        except subprocess.CalledProcessError as e:
            print(f"⚠️ 'git lfs push --all' failed: {e}. Continuing.", file=sys.stderr)

    # Prune local LFS to save disk
    try:
        run(["git", "lfs", "prune", "--force"], check=False, cwd=repo_root)
    except Exception as e:
        print(f"⚠️ 'git lfs prune' failed: {e}. Continuing.", file=sys.stderr)

    print("✅ Cleanup complete.")
|
||
|
||
|
||
# ----------------------------
# Filesystem helpers
# ----------------------------


def list_files_under(root: Path) -> list[Path]:
    """
    Recursively collect files under `root`, skipping housekeeping dirs.

    A file is excluded when any of its *parent* directories (relative to
    `root`) is hidden or a known housekeeping directory; hidden files
    themselves are still included.
    """
    skip_dirs = {".git", ".cache", ".hf_mirror_cache"}

    def _in_skipped_dir(candidate: Path) -> bool:
        # Only the directory components matter, hence [:-1].
        parents = candidate.relative_to(root).parts[:-1]
        return any(part in skip_dirs or part.startswith(".") for part in parents)

    return [
        entry
        for entry in root.rglob("*")
        if entry.is_file() and not _in_skipped_dir(entry)
    ]
|
||
|
||
|
||
# ----------------------------
# Main routine
# ----------------------------


def main() -> None:
    """
    Entry point: download/resume every variant listed in the model YAML,
    refresh each variant's file list and total size, ensure LFS tracking,
    and commit/push after every variant.
    """
    if len(sys.argv) != 2:
        print(f"Usage: {sys.argv[0]} <path-to-model.yaml>", file=sys.stderr)
        sys.exit(1)

    yaml_path = Path(sys.argv[1])
    if not yaml_path.exists():
        print(f"Model YAML not found: {yaml_path}", file=sys.stderr)
        sys.exit(1)

    repo_root = ensure_repo_root()

    # Parse the model metadata.
    with open(yaml_path, "r", encoding="utf-8") as fh:
        data = yaml.safe_load(fh) or {}

    model_dir = yaml_path.parent

    # Proactively set up LFS tracking by patterns (idempotent).
    lfs_track_patterns(LFS_PATTERNS)

    # Walk every format/variant pair declared in the YAML.
    for fmt in (data.get("model") or {}).get("formats") or []:
        for variant in fmt.get("variants") or []:
            variant_id = variant.get("id")
            hf_repo = variant.get("hf_repo")
            if not hf_repo or not variant_id:
                # Incomplete entry; nothing we can download.
                continue

            dest_path = model_dir / variant_id
            dest_path.mkdir(parents=True, exist_ok=True)

            # hf_repo may be a full URL; strip the host to get the repo id.
            repo_id = hf_repo.replace("https://huggingface.co/", "")
            print(f"\n[DL] Downloading/resuming variant '{variant_id}' from '{repo_id}' into '{dest_path}'")

            # Resume-friendly download; keep HF housekeeping out of the tree.
            try:
                snapshot_download(
                    repo_id=repo_id,
                    local_dir=str(dest_path),
                    local_dir_use_symlinks=False,
                    resume_download=True,  # explicit
                    ignore_patterns=[".cache/**"],  # prevent housekeeping into tree
                )
            except Exception as e:
                print(f"❌ snapshot_download failed for {variant_id}: {e}", file=sys.stderr)
                raise

            # Inventory the variant: model-relative file paths + cumulative
            # size, with a per-file LFS fallback above the size threshold.
            files_list: list[str] = []
            total_size_bytes = 0
            for entry in list_files_under(dest_path):
                files_list.append(str(entry.relative_to(model_dir)).replace("\\", "/"))
                try:
                    size = entry.stat().st_size
                except FileNotFoundError:
                    # File vanished mid-scan; ignore it.
                    continue
                total_size_bytes += size
                if size > SIZE_THRESHOLD_BYTES:
                    # Patterns may have missed it; track this file explicitly.
                    lfs_track_file(repo_root, repo_relative_path(repo_root, entry))

            files_list.sort()
            variant["files"] = files_list
            variant["size_bytes"] = int(total_size_bytes)

            # Persist progress after every variant so a crash loses little work.
            with open(yaml_path, "w", encoding="utf-8") as fh:
                yaml.dump(data, fh, sort_keys=False, allow_unicode=True)

            print(f"✅ Updated {yaml_path} for variant '{variant_id}'")

            # ---- Built-in cleanup (replaces cleanup.sh) ----
            print(f"🧹 Running cleanup for {variant_id}...")
            try:
                # Only scope renormalization to model_dir to keep things fast.
                if not git_status_has_changes(repo_root):
                    print("No new files or changes to commit. Skipping commit and push.")
                else:
                    git_stage_and_commit_push(
                        repo_root=repo_root,
                        scope_paths=[model_dir],
                        commit_message=f"Add/update model files for {model_dir.name}/{variant_id}",
                        lfs_push_all=LFS_PUSH_ALL,
                    )
            except subprocess.CalledProcessError as e:
                # Continue to next variant rather than aborting the whole run.
                print(f"❌ Cleanup failed (continue to next variant): {e}", file=sys.stderr)

    print(f"\n✅ Download and YAML update complete for {yaml_path}.")
|
||
|
||
|
||
# Script entry point.
if __name__ == "__main__":
    main()
|