135 lines
4.6 KiB
Python
135 lines
4.6 KiB
Python
from huggingface_hub import HfApi, HfFileSystem
|
|
from pathlib import Path
|
|
import yaml
|
|
import requests
|
|
import os
|
|
from datetime import datetime
|
|
from collections import defaultdict
|
|
import re
|
|
import sys
|
|
|
|
|
|
# Sharded weight files such as "model-00001-of-00004.safetensors" or
# "llama-q4_k_m-00001-of-00002.gguf". Group 1 is the shared prefix, group 2
# the file extension without the dot.
_SHARD_PATTERN = re.compile(r"(.+)-\d+-of-\d+\.(safetensors|gguf)$")

# Unquantized float tags such as "bf16", "f16", "fp32", matched against a
# lower-cased filename; group 1 is the bit width.
_FLOAT_TAG_PATTERN = re.compile(r"(?<![a-z0-9])b?fp?(16|32)(?![0-9])")

# Quantization tags such as "q4_k_m", "q8_0", "iq3_xs", matched against a
# lower-cased filename; group 1 is the bit width.
_QUANT_TAG_PATTERN = re.compile(r"(?<![a-z0-9])i?q(\d)(?![0-9])")

# Seconds to wait for the README download before falling back to a placeholder.
_README_TIMEOUT_SECONDS = 30


def _infer_bits(filename: str) -> int:
    """Guess the weight precision (bits per parameter) from a filename.

    Safetensors files and names containing "fp16" are assumed to be FP16.
    Otherwise a float tag ("bf16", "f32", ...) or a quantization tag
    ("Q4_K_M", "IQ3_XS", ...) supplies the width; anything else falls back
    to 8 bits.
    """
    name = filename.lower()
    if name.endswith(".safetensors") or "fp16" in name:
        return 16
    for pattern in (_FLOAT_TAG_PATTERN, _QUANT_TAG_PATTERN):
        match = pattern.search(name)
        if match:
            return int(match.group(1))
    return 8


def _infer_context_length(name: str) -> int:
    """Return 128000 when *name* mentions "128k" (any case), else 4096."""
    return 128000 if "128k" in name.lower() else 4096


def _remote_size(fs, repo_id: str, filename: str) -> int:
    """Return the size in bytes of *filename* in *repo_id*, or 0 if unreported."""
    return fs.info(f"hf://{repo_id}/{filename}").get("size", 0)


def _write_readme(repo_id: str, out_path: Path) -> None:
    """Save the repo's README.md into *out_path*, or write a placeholder.

    The raw README on the ``main`` branch is fetched over HTTP and written
    byte-for-byte, so its encoding is preserved. Any request failure
    (connection error, timeout, non-2xx status) produces a short placeholder
    README instead. Errors writing to disk propagate.
    """
    readme_url = f"https://huggingface.co/{repo_id}/raw/main/README.md"
    readme_path = out_path / "README.md"
    try:
        response = requests.get(readme_url, timeout=_README_TIMEOUT_SECONDS)
        response.raise_for_status()
    except requests.RequestException:
        readme_path.write_text(
            f"# README for {repo_id}\n(Not found on HuggingFace)", encoding="utf-8"
        )
    else:
        readme_path.write_bytes(response.content)


def _copy_license(api, repo_id: str, license_file: str, out_path: Path) -> None:
    """Download *license_file* from the Hub and copy it (as bytes) into *out_path*."""
    local_path = Path(api.hf_hub_download(repo_id, license_file))
    (out_path / Path(license_file).name).write_bytes(local_path.read_bytes())


def _build_variants(fs, repo_id: str, model_files: list) -> list:
    """Group weight files into variants, each tagged with its format.

    Returns a list of ``(format, variant)`` pairs, where format is "gguf" or
    "safetensors" (the file extension). Sharded files that share a prefix
    and extension form one variant with their shards sorted. Every other
    file is its own variant. Sharded variants come first, then unsharded
    ones, each in input order.
    """
    hf_repo = f"https://huggingface.co/{repo_id}"

    # Key by (prefix, extension) so identically named safetensors and GGUF
    # shard sets do not merge into one variant.
    shard_groups = defaultdict(list)
    unsharded_files = []
    for filename in model_files:
        match = _SHARD_PATTERN.match(filename)
        if match:
            shard_groups[(match.group(1), match.group(2))].append(filename)
        else:
            unsharded_files.append(filename)

    tagged = []
    for (prefix, ext), group in shard_groups.items():
        shards = sorted(group)
        tagged.append((ext, {
            "id": prefix,
            "label": prefix,
            "bits": _infer_bits(shards[0]),
            "context_length": _infer_context_length(prefix),
            "size_bytes": sum(_remote_size(fs, repo_id, f) for f in shards),
            "hf_repo": hf_repo,
            "files": shards,
        }))

    for filename in unsharded_files:
        tagged.append((Path(filename).suffix.lstrip("."), {
            "id": Path(filename).stem,
            "label": filename,
            "bits": _infer_bits(filename),
            "context_length": _infer_context_length(filename),
            "size_bytes": _remote_size(fs, repo_id, filename),
            "hf_repo": hf_repo,
            "files": [filename],
        }))
    return tagged


def _format_publish_date(last_modified) -> str:
    """Return the ISO date (YYYY-MM-DD) of *last_modified*, or "unknown".

    Accepts a ``datetime``, an ISO-8601 string (a trailing "Z" is read as
    UTC), or ``None`` when the Hub reports no modification time.
    """
    if last_modified is None:
        return "unknown"
    if isinstance(last_modified, str):
        last_modified = datetime.fromisoformat(last_modified.replace("Z", "+00:00"))
    return last_modified.date().isoformat()


def generate_model_bundle(repo_id: str, output_dir: str) -> str:
    """Build a model bundle directory for a Hugging Face model repository.

    Writes into *output_dir*, which is created if missing:

    * ``README.md``: the repo's README, or a placeholder if it can't be fetched;
    * the first repo file whose path contains "license" (any case), if any;
    * ``model.yaml``: metadata plus one "formats" entry per weight format
      (gguf / safetensors) found in the repo, listing its variants.

    Args:
        repo_id: Hub repository id, e.g. ``"org/model-name"``.
        output_dir: Destination directory.

    Returns:
        The output directory path as a string.

    Errors raised by the Hub client (for example for a missing repo) are not
    caught and propagate to the caller.
    """
    api = HfApi()
    fs = HfFileSystem()
    model_info = api.model_info(repo_id)

    out_path = Path(output_dir)
    out_path.mkdir(parents=True, exist_ok=True)

    # ----- 1. Fetch metadata -----
    model_card = model_info.cardData or {}
    files = api.list_repo_files(repo_id)

    # ----- 2. Filter files -----
    model_files = [f for f in files if f.endswith((".gguf", ".safetensors"))]
    tokenizer_files = [f for f in files if "tokenizer" in f.lower()]
    license_file = next((f for f in files if "license" in f.lower()), None)

    # ----- 3. README / 4. LICENSE -----
    _write_readme(repo_id, out_path)
    if license_file:
        _copy_license(api, repo_id, license_file, out_path)

    # ----- 5. Build variant groups, one "formats" entry per weight format -----
    tagged_variants = _build_variants(fs, repo_id, model_files)
    formats = [
        {"type": fmt, "variants": [v for v_fmt, v in tagged_variants if v_fmt == fmt]}
        for fmt in ("gguf", "safetensors")
        if any(v_fmt == fmt for v_fmt, _ in tagged_variants)
    ]
    if not formats:
        # No weight files found: emit a single, empty safetensors entry.
        formats = [{"type": "safetensors", "variants": []}]

    license_url = (
        f"https://huggingface.co/{repo_id}/blob/main/{license_file}"
        if license_file
        else "N/A"
    )

    # ----- 6. YAML data -----
    yaml_data = {
        "model": {
            "name": repo_id.split("/")[-1],
            "display_name": model_card.get("title", repo_id),
            "description": model_card.get("summary", "No description available."),
            # The original publisher is the repo owner, not the license id.
            "publisher_original": getattr(model_info, "author", None) or repo_id.split("/")[0],
            "publisher_quantized": "Community",
            "license": model_card.get("license", "other"),
            "license_url": license_url,
            "publish_date": _format_publish_date(model_info.lastModified),
            "modality": "text",
            # NOTE(review): hard-coded; not derived from any repo metadata.
            "thinking_model": True,
            "tokenizer": {"files": tokenizer_files},
            "architecture": model_card.get("model_architecture", "transformer"),
            "formats": formats,
        }
    }

    with open(out_path / "model.yaml", "w", encoding="utf-8") as f:
        yaml.dump(yaml_data, f, sort_keys=False)

    return str(out_path)
|
|
|
|
|
|
# -------- Entry point for CLI --------
|
|
if __name__ == "__main__":
    # CLI: python generate_model_yaml.py <repo-id> <output-folder>
    if len(sys.argv) != 3:
        # Usage errors go to stderr so they never mix with piped stdout.
        print(
            "Usage: python generate_model_yaml.py <huggingface/repo-id> <output-folder>",
            file=sys.stderr,
        )
        sys.exit(1)

    repo_id, output_dir = sys.argv[1:]
    output_path = generate_model_bundle(repo_id, output_dir)
    print(f"✅ Model bundle generated at: {output_path}")
|