generated from nhcarrigan/template
Compare commits
1 Commits
v0.1.0-alpha
..
main
| Author | SHA1 | Date | |
|---|---|---|---|
| dd14cdc421 |
@@ -1,2 +0,0 @@
|
|||||||
[target.x86_64-pc-windows-msvc]
|
|
||||||
linker = "lld-link"
|
|
||||||
@@ -6,5 +6,3 @@
|
|||||||
# Ignore binary files >:(
|
# Ignore binary files >:(
|
||||||
*.png binary
|
*.png binary
|
||||||
*.jpg binary
|
*.jpg binary
|
||||||
*.ico binary
|
|
||||||
*.icns binary
|
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
name: 🐛 Bug Report
|
name: 🐛 Bug Report
|
||||||
description: Something isn't working as expected? Let us know!
|
description: Something isn't working as expected? Let us know!
|
||||||
title: "[BUG] - "
|
title: '[BUG] - '
|
||||||
labels:
|
labels:
|
||||||
- "status/awaiting triage"
|
- "status/awaiting triage"
|
||||||
body:
|
body:
|
||||||
@@ -50,7 +50,7 @@ body:
|
|||||||
description: The operating system you are using, including the version/build number.
|
description: The operating system you are using, including the version/build number.
|
||||||
validations:
|
validations:
|
||||||
required: true
|
required: true
|
||||||
# Remove this section for non-web apps.
|
# Remove this section for non-web apps.
|
||||||
- type: input
|
- type: input
|
||||||
id: browser
|
id: browser
|
||||||
attributes:
|
attributes:
|
||||||
@@ -66,3 +66,4 @@ body:
|
|||||||
- No
|
- No
|
||||||
validations:
|
validations:
|
||||||
required: true
|
required: true
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
name: 💭 Feature Proposal
|
name: 💭 Feature Proposal
|
||||||
description: Have an idea for how we can improve? Share it here!
|
description: Have an idea for how we can improve? Share it here!
|
||||||
title: "[FEAT] - "
|
title: '[FEAT] - '
|
||||||
labels:
|
labels:
|
||||||
- "status/awaiting triage"
|
- "status/awaiting triage"
|
||||||
body:
|
body:
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
name: ❓ Other Issue
|
name: ❓ Other Issue
|
||||||
description: I have something that is neither a bug nor a feature request.
|
description: I have something that is neither a bug nor a feature request.
|
||||||
title: "[OTHER] - "
|
title: '[OTHER] - '
|
||||||
labels:
|
labels:
|
||||||
- "status/awaiting triage"
|
- "status/awaiting triage"
|
||||||
body:
|
body:
|
||||||
|
|||||||
@@ -1,201 +0,0 @@
|
|||||||
name: CI
|
|
||||||
|
|
||||||
on:
|
|
||||||
push:
|
|
||||||
branches: [main]
|
|
||||||
pull_request:
|
|
||||||
branches: [main]
|
|
||||||
workflow_dispatch:
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
lint-and-test:
|
|
||||||
name: Lint & Test
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
|
|
||||||
steps:
|
|
||||||
- name: Checkout code
|
|
||||||
uses: actions/checkout@v4
|
|
||||||
|
|
||||||
- name: Install Linux dependencies
|
|
||||||
run: |
|
|
||||||
sudo apt-get update
|
|
||||||
sudo apt-get install -y \
|
|
||||||
libwebkit2gtk-4.1-dev \
|
|
||||||
librsvg2-dev \
|
|
||||||
patchelf \
|
|
||||||
libgtk-3-dev \
|
|
||||||
libayatana-appindicator3-dev \
|
|
||||||
libasound2-dev \
|
|
||||||
pkg-config \
|
|
||||||
libclang-dev \
|
|
||||||
cmake
|
|
||||||
|
|
||||||
- name: Setup pnpm
|
|
||||||
uses: pnpm/action-setup@v4
|
|
||||||
with:
|
|
||||||
version: 9
|
|
||||||
|
|
||||||
- name: Setup Node.js
|
|
||||||
uses: actions/setup-node@v4
|
|
||||||
with:
|
|
||||||
node-version: 22
|
|
||||||
cache: pnpm
|
|
||||||
|
|
||||||
- name: Install frontend dependencies
|
|
||||||
run: pnpm install
|
|
||||||
|
|
||||||
- name: Run ESLint
|
|
||||||
run: pnpm lint
|
|
||||||
|
|
||||||
- name: Run Prettier check
|
|
||||||
run: pnpm format:check
|
|
||||||
|
|
||||||
- name: Build frontend
|
|
||||||
run: pnpm build
|
|
||||||
|
|
||||||
- name: Run frontend tests
|
|
||||||
run: pnpm test
|
|
||||||
|
|
||||||
- name: Setup Rust
|
|
||||||
uses: dtolnay/rust-toolchain@stable
|
|
||||||
with:
|
|
||||||
components: clippy
|
|
||||||
|
|
||||||
- name: Cache Rust dependencies
|
|
||||||
uses: actions/cache@v4
|
|
||||||
with:
|
|
||||||
path: |
|
|
||||||
~/.cargo/bin/
|
|
||||||
~/.cargo/registry/index/
|
|
||||||
~/.cargo/registry/cache/
|
|
||||||
~/.cargo/git/db/
|
|
||||||
src-tauri/target/
|
|
||||||
key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }}
|
|
||||||
|
|
||||||
- name: Run Clippy
|
|
||||||
working-directory: src-tauri
|
|
||||||
run: cargo clippy --all-targets --all-features -- -D warnings
|
|
||||||
|
|
||||||
- name: Run Rust tests
|
|
||||||
working-directory: src-tauri
|
|
||||||
run: cargo test
|
|
||||||
|
|
||||||
build-linux:
|
|
||||||
name: Build Linux
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
needs: lint-and-test
|
|
||||||
|
|
||||||
steps:
|
|
||||||
- name: Checkout code
|
|
||||||
uses: actions/checkout@v4
|
|
||||||
|
|
||||||
- name: Install Linux dependencies
|
|
||||||
run: |
|
|
||||||
sudo apt-get update
|
|
||||||
sudo apt-get install -y \
|
|
||||||
libwebkit2gtk-4.1-dev \
|
|
||||||
librsvg2-dev \
|
|
||||||
patchelf \
|
|
||||||
libgtk-3-dev \
|
|
||||||
libayatana-appindicator3-dev \
|
|
||||||
libasound2-dev \
|
|
||||||
pkg-config \
|
|
||||||
libclang-dev \
|
|
||||||
cmake \
|
|
||||||
xdg-utils
|
|
||||||
|
|
||||||
- name: Setup pnpm
|
|
||||||
uses: pnpm/action-setup@v4
|
|
||||||
with:
|
|
||||||
version: 9
|
|
||||||
|
|
||||||
- name: Setup Node.js
|
|
||||||
uses: actions/setup-node@v4
|
|
||||||
with:
|
|
||||||
node-version: 22
|
|
||||||
cache: pnpm
|
|
||||||
|
|
||||||
- name: Install frontend dependencies
|
|
||||||
run: pnpm install
|
|
||||||
|
|
||||||
- name: Setup Rust
|
|
||||||
uses: dtolnay/rust-toolchain@stable
|
|
||||||
|
|
||||||
- name: Cache Rust dependencies
|
|
||||||
uses: actions/cache@v4
|
|
||||||
with:
|
|
||||||
path: |
|
|
||||||
~/.cargo/bin/
|
|
||||||
~/.cargo/registry/index/
|
|
||||||
~/.cargo/registry/cache/
|
|
||||||
~/.cargo/git/db/
|
|
||||||
src-tauri/target/
|
|
||||||
key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }}
|
|
||||||
|
|
||||||
- name: Build Linux
|
|
||||||
run: pnpm build:linux
|
|
||||||
|
|
||||||
build-windows:
|
|
||||||
name: Build Windows (cross-compile)
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
needs: lint-and-test
|
|
||||||
|
|
||||||
steps:
|
|
||||||
- name: Checkout code
|
|
||||||
uses: actions/checkout@v4
|
|
||||||
|
|
||||||
- name: Install Linux dependencies for cross-compilation
|
|
||||||
run: |
|
|
||||||
sudo apt-get update
|
|
||||||
sudo apt-get install -y \
|
|
||||||
libwebkit2gtk-4.1-dev \
|
|
||||||
librsvg2-dev \
|
|
||||||
patchelf \
|
|
||||||
libgtk-3-dev \
|
|
||||||
libayatana-appindicator3-dev \
|
|
||||||
libasound2-dev \
|
|
||||||
pkg-config \
|
|
||||||
libclang-dev \
|
|
||||||
cmake \
|
|
||||||
clang \
|
|
||||||
lld \
|
|
||||||
llvm \
|
|
||||||
nsis
|
|
||||||
|
|
||||||
- name: Setup pnpm
|
|
||||||
uses: pnpm/action-setup@v4
|
|
||||||
with:
|
|
||||||
version: 9
|
|
||||||
|
|
||||||
- name: Setup Node.js
|
|
||||||
uses: actions/setup-node@v4
|
|
||||||
with:
|
|
||||||
node-version: 22
|
|
||||||
cache: pnpm
|
|
||||||
|
|
||||||
- name: Install frontend dependencies
|
|
||||||
run: pnpm install
|
|
||||||
|
|
||||||
- name: Setup Rust
|
|
||||||
uses: dtolnay/rust-toolchain@stable
|
|
||||||
with:
|
|
||||||
targets: x86_64-pc-windows-msvc
|
|
||||||
|
|
||||||
- name: Install cargo-xwin
|
|
||||||
run: |
|
|
||||||
curl -fsSL https://github.com/rust-cross/cargo-xwin/releases/download/v0.20.2/cargo-xwin-v0.20.2.x86_64-unknown-linux-musl.tar.gz | tar xz
|
|
||||||
sudo mv cargo-xwin /usr/local/bin/
|
|
||||||
|
|
||||||
- name: Cache Rust dependencies
|
|
||||||
uses: actions/cache@v4
|
|
||||||
with:
|
|
||||||
path: |
|
|
||||||
~/.cargo/bin/
|
|
||||||
~/.cargo/registry/index/
|
|
||||||
~/.cargo/registry/cache/
|
|
||||||
~/.cargo/git/db/
|
|
||||||
src-tauri/target/
|
|
||||||
key: ${{ runner.os }}-cargo-windows-${{ hashFiles('**/Cargo.lock') }}
|
|
||||||
|
|
||||||
- name: Build Windows
|
|
||||||
run: pnpm build:windows
|
|
||||||
@@ -2,11 +2,11 @@ name: Security Scan and Upload
|
|||||||
|
|
||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
branches: [main]
|
branches: [ main ]
|
||||||
pull_request:
|
pull_request:
|
||||||
branches: [main]
|
branches: [ main ]
|
||||||
schedule:
|
schedule:
|
||||||
- cron: "0 0 * * 1"
|
- cron: '0 0 * * 1'
|
||||||
workflow_dispatch:
|
workflow_dispatch:
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
|
|||||||
-61
@@ -1,61 +0,0 @@
|
|||||||
# Logs
|
|
||||||
logs
|
|
||||||
*.log
|
|
||||||
npm-debug.log*
|
|
||||||
yarn-debug.log*
|
|
||||||
yarn-error.log*
|
|
||||||
pnpm-debug.log*
|
|
||||||
lerna-debug.log*
|
|
||||||
|
|
||||||
node_modules
|
|
||||||
dist
|
|
||||||
dist-ssr
|
|
||||||
*.local
|
|
||||||
|
|
||||||
# Editor directories and files
|
|
||||||
.vscode/*
|
|
||||||
!.vscode/extensions.json
|
|
||||||
.idea
|
|
||||||
.DS_Store
|
|
||||||
*.suo
|
|
||||||
*.ntvs*
|
|
||||||
*.njsproj
|
|
||||||
*.sln
|
|
||||||
*.sw?
|
|
||||||
|
|
||||||
# Python
|
|
||||||
__pycache__/
|
|
||||||
*.py[cod]
|
|
||||||
*$py.class
|
|
||||||
*.so
|
|
||||||
.Python
|
|
||||||
env/
|
|
||||||
venv/
|
|
||||||
ENV/
|
|
||||||
.venv/
|
|
||||||
*.egg-info/
|
|
||||||
|
|
||||||
# Models - large ML model files
|
|
||||||
models/
|
|
||||||
src/pretrained_models/
|
|
||||||
src-tauri/resources/models/
|
|
||||||
*.gguf
|
|
||||||
*.bin
|
|
||||||
|
|
||||||
# Tauri
|
|
||||||
src-tauri/target/
|
|
||||||
src-tauri/WixTools/
|
|
||||||
src-tauri/resources/
|
|
||||||
|
|
||||||
# Build outputs
|
|
||||||
build/
|
|
||||||
|
|
||||||
# App data
|
|
||||||
recordings/
|
|
||||||
transcripts/
|
|
||||||
summaries/
|
|
||||||
|
|
||||||
# Environment
|
|
||||||
.env
|
|
||||||
*.env.local
|
|
||||||
prod.env
|
|
||||||
@@ -1,8 +0,0 @@
|
|||||||
build/
|
|
||||||
.svelte-kit/
|
|
||||||
dist/
|
|
||||||
src-tauri/target/
|
|
||||||
src-tauri/gen/
|
|
||||||
node_modules/
|
|
||||||
.pnpm-store/
|
|
||||||
pnpm-lock.yaml
|
|
||||||
@@ -1,7 +0,0 @@
|
|||||||
{
|
|
||||||
"semi": true,
|
|
||||||
"singleQuote": false,
|
|
||||||
"tabWidth": 2,
|
|
||||||
"trailingComma": "es5",
|
|
||||||
"printWidth": 100
|
|
||||||
}
|
|
||||||
Vendored
-3
@@ -1,3 +0,0 @@
|
|||||||
{
|
|
||||||
"recommendations": ["tauri-apps.tauri-vscode", "rust-lang.rust-analyzer"]
|
|
||||||
}
|
|
||||||
-120
@@ -1,120 +0,0 @@
|
|||||||
# Chronara - Local Meeting Transcription & Summarization
|
|
||||||
|
|
||||||
A Windows desktop application that transcribes, diarizes, and summarizes meetings using only locally-running models.
|
|
||||||
|
|
||||||
## Features
|
|
||||||
|
|
||||||
- 🎙️ Real-time audio transcription with speaker diarization (WhisperX)
|
|
||||||
- 📝 Intelligent meeting summarization (Llama 3.2)
|
|
||||||
- 🖥️ Everything runs locally - no cloud services required
|
|
||||||
- 📦 All models bundled - no separate downloads needed
|
|
||||||
|
|
||||||
## Tech Stack
|
|
||||||
|
|
||||||
- **Transcription**: WhisperX (Whisper + speaker diarization)
|
|
||||||
- **Summarization**: Llama 3.2 1B/3B
|
|
||||||
- **Backend**: Python with FastAPI
|
|
||||||
- **Frontend**: Tauri + React
|
|
||||||
- **Model Runtime**: llama-cpp-python
|
|
||||||
|
|
||||||
## Project Structure
|
|
||||||
|
|
||||||
```
|
|
||||||
chronara/
|
|
||||||
├── src/
|
|
||||||
│ ├── backend/ # Python FastAPI backend
|
|
||||||
│ ├── components/ # React components
|
|
||||||
│ └── App.tsx # Main React app
|
|
||||||
├── src-tauri/ # Tauri configuration
|
|
||||||
├── models/ # Bundled model files
|
|
||||||
├── scripts/ # Build and setup scripts
|
|
||||||
└── assets/ # Icons, resources
|
|
||||||
```
|
|
||||||
|
|
||||||
## Development Setup
|
|
||||||
|
|
||||||
### Prerequisites
|
|
||||||
|
|
||||||
- Node.js 18+ with pnpm
|
|
||||||
- Python 3.10+
|
|
||||||
- Rust (for Tauri)
|
|
||||||
- Windows build tools (for native modules)
|
|
||||||
|
|
||||||
### Installation
|
|
||||||
|
|
||||||
1. Clone the repository:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
git clone https://github.com/naomi-lgbt/chronara.git
|
|
||||||
cd chronara
|
|
||||||
```
|
|
||||||
|
|
||||||
2. Install frontend dependencies:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
pnpm install
|
|
||||||
```
|
|
||||||
|
|
||||||
3. Install Python dependencies:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
pip install -r requirements.txt
|
|
||||||
```
|
|
||||||
|
|
||||||
4. Download the AI models:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
python scripts/download_models.py
|
|
||||||
```
|
|
||||||
|
|
||||||
5. Run in development mode:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
pnpm tauri:dev
|
|
||||||
```
|
|
||||||
|
|
||||||
## Building for Production
|
|
||||||
|
|
||||||
### Windows
|
|
||||||
|
|
||||||
1. Download models if not already done:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
python scripts/download_models.py
|
|
||||||
```
|
|
||||||
|
|
||||||
2. Build the Windows executable:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
python scripts/build_windows.py
|
|
||||||
```
|
|
||||||
|
|
||||||
The installer will be created in `src-tauri/target/release/bundle/nsis/`.
|
|
||||||
|
|
||||||
## Usage
|
|
||||||
|
|
||||||
1. **Start Recording**: Click the "Start Recording" button to begin capturing audio
|
|
||||||
2. **Real-time Transcription**: Watch as the conversation is transcribed with speaker labels
|
|
||||||
3. **Generate Summary**: After recording, click "Generate Summary" for an AI-powered meeting summary
|
|
||||||
4. **Export**: Download both the full transcript and summary as text files
|
|
||||||
|
|
||||||
## Model Information
|
|
||||||
|
|
||||||
### Transcription (WhisperX)
|
|
||||||
|
|
||||||
- **Model**: OpenAI Whisper base model with WhisperX enhancements
|
|
||||||
- **Features**: Speaker diarization, timestamp alignment
|
|
||||||
- **Size**: ~150MB
|
|
||||||
|
|
||||||
### Summarization (Llama 3.2)
|
|
||||||
|
|
||||||
- **1B Model**: Fast, good for basic summaries (~1.2GB)
|
|
||||||
- **3B Model**: Better quality summaries (~2.5GB)
|
|
||||||
- **Format**: GGUF quantized models
|
|
||||||
|
|
||||||
## Privacy & Security
|
|
||||||
|
|
||||||
- All processing happens locally on your machine
|
|
||||||
- No audio or text data is sent to external servers
|
|
||||||
- Models are bundled with the application
|
|
||||||
- Meeting data stays on your device
|
|
||||||
@@ -1,6 +1,16 @@
|
|||||||
# Chronara
|
# New Repository Template
|
||||||
|
|
||||||
A meeting transcription and summarisation tool that uses 100% local models.
|
This template contains all of our basic files for a new GitHub repository. There is also a handy workflow that will create an issue on a new repository made from this template, with a checklist for the steps we usually take in setting up a new repository.
|
||||||
|
|
||||||
|
If you're starting a Node.JS project with TypeScript, we have a [specific template](https://github.com/naomi-lgbt/nodejs-typescript-template) for that purpose.
|
||||||
|
|
||||||
|
## Readme
|
||||||
|
|
||||||
|
Delete all of the above text (including this line), and uncomment the below text to use our standard readme template.
|
||||||
|
|
||||||
|
<!-- # Project Name
|
||||||
|
|
||||||
|
Project Description
|
||||||
|
|
||||||
## Live Version
|
## Live Version
|
||||||
|
|
||||||
@@ -8,7 +18,7 @@ This page is currently deployed. [View the live website.]
|
|||||||
|
|
||||||
## Feedback and Bugs
|
## Feedback and Bugs
|
||||||
|
|
||||||
If you have feedback or a bug report, please feel free to open a GitHub issue!
|
If you have feedback or a bug report, please [log a ticket on our forum](https://support.nhcarrigan.com).
|
||||||
|
|
||||||
## Contributing
|
## Contributing
|
||||||
|
|
||||||
@@ -26,4 +36,4 @@ Copyright held by Naomi Carrigan.
|
|||||||
|
|
||||||
## Contact
|
## Contact
|
||||||
|
|
||||||
We may be contacted through our [Chat Server](http://chat.nhcarrigan.com) or via email at `contact@nhcarrigan.com`.
|
We may be contacted through our [Chat Server](http://chat.nhcarrigan.com) or via email at `contact@nhcarrigan.com`. -->
|
||||||
|
|||||||
@@ -1,40 +0,0 @@
|
|||||||
#!/bin/bash
|
|
||||||
set -e
|
|
||||||
|
|
||||||
echo "🔍 Running all checks..."
|
|
||||||
echo "========================================"
|
|
||||||
|
|
||||||
echo ""
|
|
||||||
echo "📦 Installing dependencies..."
|
|
||||||
pnpm install
|
|
||||||
|
|
||||||
echo ""
|
|
||||||
echo "🔎 Running ESLint..."
|
|
||||||
pnpm lint
|
|
||||||
|
|
||||||
echo ""
|
|
||||||
echo "💅 Running Prettier check..."
|
|
||||||
pnpm format:check
|
|
||||||
|
|
||||||
echo ""
|
|
||||||
echo "🏗️ Building frontend..."
|
|
||||||
pnpm build
|
|
||||||
|
|
||||||
echo ""
|
|
||||||
echo "🧪 Running frontend tests..."
|
|
||||||
pnpm test
|
|
||||||
|
|
||||||
echo ""
|
|
||||||
echo "🦀 Running Clippy..."
|
|
||||||
cd src-tauri
|
|
||||||
cargo clippy --all-targets --all-features -- -D warnings
|
|
||||||
|
|
||||||
echo ""
|
|
||||||
echo "🧪 Running Rust tests..."
|
|
||||||
cargo test
|
|
||||||
|
|
||||||
cd ..
|
|
||||||
|
|
||||||
echo ""
|
|
||||||
echo "========================================"
|
|
||||||
echo "✅ All checks passed!"
|
|
||||||
@@ -1,34 +0,0 @@
|
|||||||
import js from "@eslint/js";
|
|
||||||
import tseslint from "typescript-eslint";
|
|
||||||
import reactHooks from "eslint-plugin-react-hooks";
|
|
||||||
import reactRefresh from "eslint-plugin-react-refresh";
|
|
||||||
import prettier from "eslint-config-prettier";
|
|
||||||
import globals from "globals";
|
|
||||||
|
|
||||||
export default tseslint.config(
|
|
||||||
js.configs.recommended,
|
|
||||||
...tseslint.configs.recommended,
|
|
||||||
prettier,
|
|
||||||
{
|
|
||||||
languageOptions: {
|
|
||||||
globals: {
|
|
||||||
...globals.browser,
|
|
||||||
...globals.node,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
files: ["**/*.{ts,tsx}"],
|
|
||||||
plugins: {
|
|
||||||
"react-hooks": reactHooks,
|
|
||||||
"react-refresh": reactRefresh,
|
|
||||||
},
|
|
||||||
rules: {
|
|
||||||
...reactHooks.configs.recommended.rules,
|
|
||||||
"react-refresh/only-export-components": ["warn", { allowConstantExport: true }],
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
ignores: ["build/", "dist/", "src-tauri/target/", "node_modules/"],
|
|
||||||
}
|
|
||||||
);
|
|
||||||
-14
@@ -1,14 +0,0 @@
|
|||||||
<!doctype html>
|
|
||||||
<html lang="en">
|
|
||||||
<head>
|
|
||||||
<meta charset="UTF-8" />
|
|
||||||
<link rel="icon" type="image/svg+xml" href="/vite.svg" />
|
|
||||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
|
||||||
<title>Tauri + React + Typescript</title>
|
|
||||||
</head>
|
|
||||||
|
|
||||||
<body>
|
|
||||||
<div id="root"></div>
|
|
||||||
<script type="module" src="/src/main.tsx"></script>
|
|
||||||
</body>
|
|
||||||
</html>
|
|
||||||
@@ -1,49 +0,0 @@
|
|||||||
{
|
|
||||||
"name": "chronara",
|
|
||||||
"private": true,
|
|
||||||
"version": "0.1.0",
|
|
||||||
"type": "module",
|
|
||||||
"scripts": {
|
|
||||||
"dev": "vite",
|
|
||||||
"build": "tsc && vite build",
|
|
||||||
"lint": "eslint src",
|
|
||||||
"lint:fix": "eslint src --fix",
|
|
||||||
"format": "prettier --write .",
|
|
||||||
"format:check": "prettier --check .",
|
|
||||||
"preview": "vite preview",
|
|
||||||
"tauri": "tauri",
|
|
||||||
"tauri:dev": "tauri dev",
|
|
||||||
"build:linux": "tauri build",
|
|
||||||
"build:windows": "./scripts/build-windows-nsis.sh",
|
|
||||||
"build:all": "pnpm build:linux && pnpm build:windows",
|
|
||||||
"test": "vitest run",
|
|
||||||
"test:watch": "vitest",
|
|
||||||
"test:coverage": "vitest run --coverage"
|
|
||||||
},
|
|
||||||
"dependencies": {
|
|
||||||
"@tauri-apps/api": "^2",
|
|
||||||
"@tauri-apps/plugin-opener": "^2",
|
|
||||||
"react": "^19.1.0",
|
|
||||||
"react-dom": "^19.1.0"
|
|
||||||
},
|
|
||||||
"devDependencies": {
|
|
||||||
"@eslint/js": "^9.19.0",
|
|
||||||
"@tauri-apps/cli": "^2",
|
|
||||||
"@testing-library/jest-dom": "^6.9.1",
|
|
||||||
"@testing-library/react": "^16.3.0",
|
|
||||||
"@types/react": "^19.1.8",
|
|
||||||
"@types/react-dom": "^19.1.6",
|
|
||||||
"@vitejs/plugin-react": "^4.6.0",
|
|
||||||
"eslint": "^9.19.0",
|
|
||||||
"eslint-config-prettier": "^10.1.8",
|
|
||||||
"eslint-plugin-react-hooks": "^5.2.0",
|
|
||||||
"eslint-plugin-react-refresh": "^0.4.20",
|
|
||||||
"globals": "^17.0.0",
|
|
||||||
"jsdom": "^27.4.0",
|
|
||||||
"prettier": "^3.8.0",
|
|
||||||
"typescript": "~5.8.3",
|
|
||||||
"typescript-eslint": "^8.53.0",
|
|
||||||
"vite": "^7.0.4",
|
|
||||||
"vitest": "^4.0.17"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Generated
-373
@@ -1,373 +0,0 @@
|
|||||||
# This file is automatically @generated by Cargo.
|
|
||||||
# It is not intended for manual editing.
|
|
||||||
version = 4
|
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "aho-corasick"
|
|
||||||
version = "1.1.3"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "8e60d3430d3a69478ad0993f19238d2df97c507009a52b3c10addcd7f6bcb916"
|
|
||||||
dependencies = [
|
|
||||||
"memchr",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "bindgen"
|
|
||||||
version = "0.72.1"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "993776b509cfb49c750f11b8f07a46fa23e0a1386ffc01fb1e7d343efc387895"
|
|
||||||
dependencies = [
|
|
||||||
"bitflags",
|
|
||||||
"cexpr",
|
|
||||||
"clang-sys",
|
|
||||||
"itertools",
|
|
||||||
"log",
|
|
||||||
"prettyplease",
|
|
||||||
"proc-macro2",
|
|
||||||
"quote",
|
|
||||||
"regex",
|
|
||||||
"rustc-hash",
|
|
||||||
"shlex",
|
|
||||||
"syn",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "bitflags"
|
|
||||||
version = "2.5.0"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "cf4b9d6a944f767f8e5e0db018570623c85f3d925ac718db4e06d0187adb21c1"
|
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "cc"
|
|
||||||
version = "1.2.49"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "90583009037521a116abf44494efecd645ba48b6622457080f080b85544e2215"
|
|
||||||
dependencies = [
|
|
||||||
"find-msvc-tools",
|
|
||||||
"jobserver",
|
|
||||||
"libc",
|
|
||||||
"shlex",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "cexpr"
|
|
||||||
version = "0.6.0"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "6fac387a98bb7c37292057cffc56d62ecb629900026402633ae9160df93a8766"
|
|
||||||
dependencies = [
|
|
||||||
"nom",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "cfg-if"
|
|
||||||
version = "1.0.0"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
|
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "clang-sys"
|
|
||||||
version = "1.8.1"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "0b023947811758c97c59bf9d1c188fd619ad4718dcaa767947df1cadb14f39f4"
|
|
||||||
dependencies = [
|
|
||||||
"glob",
|
|
||||||
"libc",
|
|
||||||
"libloading",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "cmake"
|
|
||||||
version = "0.1.56"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "b042e5d8a74ae91bb0961acd039822472ec99f8ab0948cbf6d1369588f8be586"
|
|
||||||
dependencies = [
|
|
||||||
"cc",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "either"
|
|
||||||
version = "1.12.0"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "3dca9240753cf90908d7e4aac30f630662b02aebaa1b58a3cadabdb23385b58b"
|
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "find-msvc-tools"
|
|
||||||
version = "0.1.5"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "3a3076410a55c90011c298b04d0cfa770b00fa04e1e3c97d3f6c9de105a03844"
|
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "find_cuda_helper"
|
|
||||||
version = "0.2.0"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "f9f9e65c593dd01ac77daad909ea4ad17f0d6d1776193fc8ea766356177abdad"
|
|
||||||
dependencies = [
|
|
||||||
"glob",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "glob"
|
|
||||||
version = "0.3.3"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "0cc23270f6e1808e30a928bdc84dea0b9b4136a8bc82338574f23baf47bbd280"
|
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "itertools"
|
|
||||||
version = "0.12.1"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569"
|
|
||||||
dependencies = [
|
|
||||||
"either",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "jobserver"
|
|
||||||
version = "0.1.31"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "d2b099aaa34a9751c5bf0878add70444e1ed2dd73f347be99003d4577277de6e"
|
|
||||||
dependencies = [
|
|
||||||
"libc",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "libc"
|
|
||||||
version = "0.2.155"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "97b3888a4aecf77e811145cadf6eef5901f4782c53886191b2f693f24761847c"
|
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "libloading"
|
|
||||||
version = "0.8.3"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "0c2a198fb6b0eada2a8df47933734e6d35d350665a33a3593d7164fa52c75c19"
|
|
||||||
dependencies = [
|
|
||||||
"cfg-if",
|
|
||||||
"windows-targets",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "llama-cpp-sys-2"
|
|
||||||
version = "0.1.132"
|
|
||||||
dependencies = [
|
|
||||||
"bindgen",
|
|
||||||
"cc",
|
|
||||||
"cmake",
|
|
||||||
"find_cuda_helper",
|
|
||||||
"glob",
|
|
||||||
"walkdir",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "log"
|
|
||||||
version = "0.4.21"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "90ed8c1e510134f979dbc4f070f87d4313098b704861a105fe34231c70a3901c"
|
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "memchr"
|
|
||||||
version = "2.7.4"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3"
|
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "minimal-lexical"
|
|
||||||
version = "0.2.1"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a"
|
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "nom"
|
|
||||||
version = "7.1.3"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "d273983c5a657a70a3e8f2a01329822f3b8c8172b73826411a55751e404a0a4a"
|
|
||||||
dependencies = [
|
|
||||||
"memchr",
|
|
||||||
"minimal-lexical",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "prettyplease"
|
|
||||||
version = "0.2.20"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "5f12335488a2f3b0a83b14edad48dca9879ce89b2edd10e80237e4e852dd645e"
|
|
||||||
dependencies = [
|
|
||||||
"proc-macro2",
|
|
||||||
"syn",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "proc-macro2"
|
|
||||||
version = "1.0.85"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "22244ce15aa966053a896d1accb3a6e68469b97c7f33f284b99f0d576879fc23"
|
|
||||||
dependencies = [
|
|
||||||
"unicode-ident",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "quote"
|
|
||||||
version = "1.0.36"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "0fa76aaf39101c457836aec0ce2316dbdc3ab723cdda1c6bd4e6ad4208acaca7"
|
|
||||||
dependencies = [
|
|
||||||
"proc-macro2",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "regex"
|
|
||||||
version = "1.10.5"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "b91213439dad192326a0d7c6ee3955910425f441d7038e0d6933b0aec5c4517f"
|
|
||||||
dependencies = [
|
|
||||||
"aho-corasick",
|
|
||||||
"memchr",
|
|
||||||
"regex-automata",
|
|
||||||
"regex-syntax",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "regex-automata"
|
|
||||||
version = "0.4.7"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "38caf58cc5ef2fed281f89292ef23f6365465ed9a41b7a7754eb4e26496c92df"
|
|
||||||
dependencies = [
|
|
||||||
"aho-corasick",
|
|
||||||
"memchr",
|
|
||||||
"regex-syntax",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "regex-syntax"
|
|
||||||
version = "0.8.4"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "7a66a03ae7c801facd77a29370b4faec201768915ac14a721ba36f20bc9c209b"
|
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "rustc-hash"
|
|
||||||
version = "2.1.1"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d"
|
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "same-file"
|
|
||||||
version = "1.0.6"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502"
|
|
||||||
dependencies = [
|
|
||||||
"winapi-util",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "shlex"
|
|
||||||
version = "1.3.0"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64"
|
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "syn"
|
|
||||||
version = "2.0.87"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "25aa4ce346d03a6dcd68dd8b4010bcb74e54e62c90c573f394c46eae99aba32d"
|
|
||||||
dependencies = [
|
|
||||||
"proc-macro2",
|
|
||||||
"quote",
|
|
||||||
"unicode-ident",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "unicode-ident"
|
|
||||||
version = "1.0.12"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b"
|
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "walkdir"
|
|
||||||
version = "2.5.0"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "29790946404f91d9c5d06f9874efddea1dc06c5efe94541a7d6863108e3a5e4b"
|
|
||||||
dependencies = [
|
|
||||||
"same-file",
|
|
||||||
"winapi-util",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "winapi-util"
|
|
||||||
version = "0.1.9"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "cf221c93e13a30d793f7645a0e7762c55d169dbb0a49671918a2319d289b10bb"
|
|
||||||
dependencies = [
|
|
||||||
"windows-sys",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "windows-sys"
|
|
||||||
version = "0.52.0"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d"
|
|
||||||
dependencies = [
|
|
||||||
"windows-targets",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "windows-targets"
|
|
||||||
version = "0.52.5"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "6f0713a46559409d202e70e28227288446bf7841d3211583a4b53e3f6d96e7eb"
|
|
||||||
dependencies = [
|
|
||||||
"windows_aarch64_gnullvm",
|
|
||||||
"windows_aarch64_msvc",
|
|
||||||
"windows_i686_gnu",
|
|
||||||
"windows_i686_gnullvm",
|
|
||||||
"windows_i686_msvc",
|
|
||||||
"windows_x86_64_gnu",
|
|
||||||
"windows_x86_64_gnullvm",
|
|
||||||
"windows_x86_64_msvc",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "windows_aarch64_gnullvm"
|
|
||||||
version = "0.52.5"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "7088eed71e8b8dda258ecc8bac5fb1153c5cffaf2578fc8ff5d61e23578d3263"
|
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "windows_aarch64_msvc"
|
|
||||||
version = "0.52.5"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "9985fd1504e250c615ca5f281c3f7a6da76213ebd5ccc9561496568a2752afb6"
|
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "windows_i686_gnu"
|
|
||||||
version = "0.52.5"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "88ba073cf16d5372720ec942a8ccbf61626074c6d4dd2e745299726ce8b89670"
|
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "windows_i686_gnullvm"
|
|
||||||
version = "0.52.5"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "87f4261229030a858f36b459e748ae97545d6f1ec60e5e0d6a3d32e0dc232ee9"
|
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "windows_i686_msvc"
|
|
||||||
version = "0.52.5"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "db3c2bf3d13d5b658be73463284eaf12830ac9a26a90c717b7f771dfe97487bf"
|
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "windows_x86_64_gnu"
|
|
||||||
version = "0.52.5"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "4e4246f76bdeff09eb48875a0fd3e2af6aada79d409d33011886d3e1581517d9"
|
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "windows_x86_64_gnullvm"
|
|
||||||
version = "0.52.5"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "852298e482cd67c356ddd9570386e2862b5673c85bd5f88df9ab6802b334c596"
|
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "windows_x86_64_msvc"
|
|
||||||
version = "0.52.5"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "bec47e5bfd1bff0eeaf6d8b485cc1074891a197ab4225d504cb7a1ab88b02bf0"
|
|
||||||
@@ -1,104 +0,0 @@
|
|||||||
# THIS FILE IS AUTOMATICALLY GENERATED BY CARGO
|
|
||||||
#
|
|
||||||
# When uploading crates to the registry Cargo will automatically
|
|
||||||
# "normalize" Cargo.toml files for maximal compatibility
|
|
||||||
# with all versions of Cargo and also rewrite `path` dependencies
|
|
||||||
# to registry (e.g., crates.io) dependencies.
|
|
||||||
#
|
|
||||||
# If you are reading this file be aware that the original Cargo.toml
|
|
||||||
# will likely look very different (and much more reasonable).
|
|
||||||
# See Cargo.toml.orig for the original contents.
|
|
||||||
|
|
||||||
[package]
|
|
||||||
edition = "2021"
|
|
||||||
name = "llama-cpp-sys-2"
|
|
||||||
version = "0.1.132"
|
|
||||||
build = "build.rs"
|
|
||||||
links = "llama"
|
|
||||||
include = [
|
|
||||||
"wrapper.h",
|
|
||||||
"wrapper_mtmd.h",
|
|
||||||
"build.rs",
|
|
||||||
"/src",
|
|
||||||
"/llama.cpp/common/**/*.h",
|
|
||||||
"/llama.cpp/common/**/*.hpp",
|
|
||||||
"/llama.cpp/common/**/*.cpp",
|
|
||||||
"/llama.cpp/ggml/include/*.h",
|
|
||||||
"/llama.cpp/ggml/src/*.h",
|
|
||||||
"/llama.cpp/ggml/src/*.c",
|
|
||||||
"/llama.cpp/ggml/src/*.cpp",
|
|
||||||
"/llama.cpp/src/*.h",
|
|
||||||
"/llama.cpp/src/*.cpp",
|
|
||||||
"/llama.cpp/src/models/*.h",
|
|
||||||
"/llama.cpp/src/models/*.cpp",
|
|
||||||
"/llama.cpp/tools/mtmd/*.h",
|
|
||||||
"/llama.cpp/tools/mtmd/*.cpp",
|
|
||||||
"/llama.cpp/convert_hf_to_gguf.py",
|
|
||||||
"/llama.cpp/common/build-info.cpp.in",
|
|
||||||
"/llama.cpp/ggml/src/ggml-cuda.cu",
|
|
||||||
"/llama.cpp/ggml/src/ggml-metal.m",
|
|
||||||
"/llama.cpp/ggml/src/ggml-metal.metal",
|
|
||||||
"/llama.cpp/include/llama.h",
|
|
||||||
"/llama.cpp/include/llama-cpp.h",
|
|
||||||
"/llama.cpp/ggml/src/ggml-cpu/**/*",
|
|
||||||
"/llama.cpp/ggml/src/ggml-cuda/**/*",
|
|
||||||
"/llama.cpp/ggml/src/ggml-metal/**/*",
|
|
||||||
"/llama.cpp/ggml/src/ggml-vulkan/**/*",
|
|
||||||
"/llama.cpp/ggml/src/llamafile/sgemm.h",
|
|
||||||
"/llama.cpp/ggml/src/llamafile/sgemm.cpp",
|
|
||||||
"/llama.cpp/pocs",
|
|
||||||
"/llama.cpp/vendor",
|
|
||||||
"/llama.cpp/CMakeLists.txt",
|
|
||||||
"/llama.cpp/common/CMakeLists.txt",
|
|
||||||
"/llama.cpp/ggml/CMakeLists.txt",
|
|
||||||
"/llama.cpp/ggml/src/CMakeLists.txt",
|
|
||||||
"/llama.cpp/src/CMakeLists.txt",
|
|
||||||
"/llama.cpp/cmake",
|
|
||||||
"/llama.cpp/ggml/cmake",
|
|
||||||
"/llama.cpp/common/cmake",
|
|
||||||
]
|
|
||||||
autolib = false
|
|
||||||
autobins = false
|
|
||||||
autoexamples = false
|
|
||||||
autotests = false
|
|
||||||
autobenches = false
|
|
||||||
description = "Low Level Bindings to llama.cpp"
|
|
||||||
readme = "README.md"
|
|
||||||
license = "MIT OR Apache-2.0"
|
|
||||||
repository = "https://github.com/utilityai/llama-cpp-rs"
|
|
||||||
|
|
||||||
[features]
|
|
||||||
cuda = []
|
|
||||||
cuda-no-vmm = ["cuda"]
|
|
||||||
dynamic-link = []
|
|
||||||
metal = []
|
|
||||||
mtmd = []
|
|
||||||
openmp = []
|
|
||||||
shared-stdcxx = []
|
|
||||||
system-ggml = []
|
|
||||||
vulkan = []
|
|
||||||
|
|
||||||
[lib]
|
|
||||||
name = "llama_cpp_sys_2"
|
|
||||||
path = "src/lib.rs"
|
|
||||||
|
|
||||||
[dependencies]
|
|
||||||
|
|
||||||
[build-dependencies.bindgen]
|
|
||||||
version = "0.72.1"
|
|
||||||
|
|
||||||
[build-dependencies.cc]
|
|
||||||
version = "1.2.49"
|
|
||||||
features = ["parallel"]
|
|
||||||
|
|
||||||
[build-dependencies.cmake]
|
|
||||||
version = "0.1"
|
|
||||||
|
|
||||||
[build-dependencies.find_cuda_helper]
|
|
||||||
version = "0.2.0"
|
|
||||||
|
|
||||||
[build-dependencies.glob]
|
|
||||||
version = "0.3.3"
|
|
||||||
|
|
||||||
[build-dependencies.walkdir]
|
|
||||||
version = "2"
|
|
||||||
Generated
-85
@@ -1,85 +0,0 @@
|
|||||||
[package]
|
|
||||||
name = "llama-cpp-sys-2"
|
|
||||||
description = "Low Level Bindings to llama.cpp"
|
|
||||||
version = "0.1.132"
|
|
||||||
edition = "2021"
|
|
||||||
license = "MIT OR Apache-2.0"
|
|
||||||
repository = "https://github.com/utilityai/llama-cpp-rs"
|
|
||||||
links = "llama"
|
|
||||||
|
|
||||||
include = [
|
|
||||||
"wrapper.h",
|
|
||||||
"wrapper_mtmd.h",
|
|
||||||
"build.rs",
|
|
||||||
"/src",
|
|
||||||
|
|
||||||
"/llama.cpp/common/**/*.h",
|
|
||||||
"/llama.cpp/common/**/*.hpp",
|
|
||||||
"/llama.cpp/common/**/*.cpp",
|
|
||||||
"/llama.cpp/ggml/include/*.h",
|
|
||||||
"/llama.cpp/ggml/src/*.h",
|
|
||||||
"/llama.cpp/ggml/src/*.c",
|
|
||||||
"/llama.cpp/ggml/src/*.cpp",
|
|
||||||
"/llama.cpp/src/*.h",
|
|
||||||
"/llama.cpp/src/*.cpp",
|
|
||||||
"/llama.cpp/src/models/*.h",
|
|
||||||
"/llama.cpp/src/models/*.cpp",
|
|
||||||
"/llama.cpp/tools/mtmd/*.h",
|
|
||||||
"/llama.cpp/tools/mtmd/*.cpp",
|
|
||||||
|
|
||||||
"/llama.cpp/convert_hf_to_gguf.py", # Yes, it's required
|
|
||||||
"/llama.cpp/common/build-info.cpp.in",
|
|
||||||
|
|
||||||
"/llama.cpp/ggml/src/ggml-cuda.cu",
|
|
||||||
"/llama.cpp/ggml/src/ggml-metal.m",
|
|
||||||
"/llama.cpp/ggml/src/ggml-metal.metal",
|
|
||||||
|
|
||||||
"/llama.cpp/include/llama.h",
|
|
||||||
"/llama.cpp/include/llama-cpp.h",
|
|
||||||
|
|
||||||
"/llama.cpp/ggml/src/ggml-cpu/**/*",
|
|
||||||
"/llama.cpp/ggml/src/ggml-cuda/**/*",
|
|
||||||
"/llama.cpp/ggml/src/ggml-metal/**/*",
|
|
||||||
"/llama.cpp/ggml/src/ggml-vulkan/**/*",
|
|
||||||
|
|
||||||
"/llama.cpp/ggml/src/llamafile/sgemm.h",
|
|
||||||
"/llama.cpp/ggml/src/llamafile/sgemm.cpp",
|
|
||||||
|
|
||||||
"/llama.cpp/pocs",
|
|
||||||
"/llama.cpp/vendor",
|
|
||||||
|
|
||||||
"/llama.cpp/CMakeLists.txt",
|
|
||||||
"/llama.cpp/common/CMakeLists.txt",
|
|
||||||
"/llama.cpp/ggml/CMakeLists.txt",
|
|
||||||
"/llama.cpp/ggml/src/CMakeLists.txt",
|
|
||||||
"/llama.cpp/src/CMakeLists.txt",
|
|
||||||
|
|
||||||
"/llama.cpp/cmake",
|
|
||||||
"/llama.cpp/ggml/cmake",
|
|
||||||
"/llama.cpp/common/cmake",
|
|
||||||
]
|
|
||||||
|
|
||||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
|
||||||
|
|
||||||
[dependencies]
|
|
||||||
|
|
||||||
[build-dependencies]
|
|
||||||
bindgen = { workspace = true }
|
|
||||||
cc = { workspace = true, features = ["parallel"] }
|
|
||||||
cmake = "0.1"
|
|
||||||
find_cuda_helper = "0.2.0"
|
|
||||||
glob = "0.3.3"
|
|
||||||
walkdir = "2"
|
|
||||||
|
|
||||||
[features]
|
|
||||||
cuda = []
|
|
||||||
# Disables the need to dynamically link against libcuda.so / cuda.dll
|
|
||||||
cuda-no-vmm = ["cuda"]
|
|
||||||
metal = []
|
|
||||||
dynamic-link = []
|
|
||||||
vulkan = []
|
|
||||||
openmp = []
|
|
||||||
# Only has an impact on Android.
|
|
||||||
shared-stdcxx = []
|
|
||||||
system-ggml = []
|
|
||||||
mtmd = []
|
|
||||||
@@ -1,5 +0,0 @@
|
|||||||
# llama-cpp-sys
|
|
||||||
|
|
||||||
Raw bindings to llama.cpp with cuda support.
|
|
||||||
|
|
||||||
See [llama-cpp-2](https://crates.io/crates/llama-cpp-2) for a safe API.
|
|
||||||
@@ -1,952 +0,0 @@
|
|||||||
use cmake::Config;
|
|
||||||
use glob::glob;
|
|
||||||
use std::env;
|
|
||||||
use std::path::{Path, PathBuf};
|
|
||||||
use std::process::Command;
|
|
||||||
use walkdir::DirEntry;
|
|
||||||
|
|
||||||
enum WindowsVariant {
|
|
||||||
Msvc,
|
|
||||||
Other,
|
|
||||||
}
|
|
||||||
|
|
||||||
enum AppleVariant {
|
|
||||||
MacOS,
|
|
||||||
Other,
|
|
||||||
}
|
|
||||||
|
|
||||||
enum TargetOs {
|
|
||||||
Windows(WindowsVariant),
|
|
||||||
Apple(AppleVariant),
|
|
||||||
Linux,
|
|
||||||
Android,
|
|
||||||
}
|
|
||||||
|
|
||||||
macro_rules! debug_log {
|
|
||||||
($($arg:tt)*) => {
|
|
||||||
if std::env::var("BUILD_DEBUG").is_ok() {
|
|
||||||
println!("cargo:warning=[DEBUG] {}", format!($($arg)*));
|
|
||||||
}
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
fn parse_target_os() -> Result<(TargetOs, String), String> {
|
|
||||||
let target = env::var("TARGET").unwrap();
|
|
||||||
|
|
||||||
if target.contains("windows") {
|
|
||||||
if target.ends_with("-windows-msvc") {
|
|
||||||
Ok((TargetOs::Windows(WindowsVariant::Msvc), target))
|
|
||||||
} else {
|
|
||||||
Ok((TargetOs::Windows(WindowsVariant::Other), target))
|
|
||||||
}
|
|
||||||
} else if target.contains("apple") {
|
|
||||||
if target.ends_with("-apple-darwin") {
|
|
||||||
Ok((TargetOs::Apple(AppleVariant::MacOS), target))
|
|
||||||
} else {
|
|
||||||
Ok((TargetOs::Apple(AppleVariant::Other), target))
|
|
||||||
}
|
|
||||||
} else if target.contains("android")
|
|
||||||
|| target == "aarch64-linux-android"
|
|
||||||
|| target == "armv7-linux-androideabi"
|
|
||||||
|| target == "i686-linux-android"
|
|
||||||
|| target == "x86_64-linux-android"
|
|
||||||
{
|
|
||||||
// Handle both full android targets and short names like arm64-v8a that cargo ndk might use
|
|
||||||
Ok((TargetOs::Android, target))
|
|
||||||
} else if target.contains("linux") {
|
|
||||||
Ok((TargetOs::Linux, target))
|
|
||||||
} else {
|
|
||||||
Err(target)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn get_cargo_target_dir() -> Result<PathBuf, Box<dyn std::error::Error>> {
|
|
||||||
let out_dir = env::var("OUT_DIR")?;
|
|
||||||
let path = PathBuf::from(out_dir);
|
|
||||||
let target_dir = path
|
|
||||||
.ancestors()
|
|
||||||
.nth(3)
|
|
||||||
.ok_or("OUT_DIR is not deep enough")?;
|
|
||||||
Ok(target_dir.to_path_buf())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn extract_lib_names(out_dir: &Path, build_shared_libs: bool) -> Vec<String> {
|
|
||||||
// Use CARGO_CFG_TARGET_OS to detect TARGET platform, not HOST
|
|
||||||
// This fixes cross-compilation from Linux to Windows
|
|
||||||
let target_os = std::env::var("CARGO_CFG_TARGET_OS").unwrap_or_default();
|
|
||||||
let lib_pattern = if target_os == "windows" {
|
|
||||||
"*.lib"
|
|
||||||
} else if target_os == "macos" {
|
|
||||||
if build_shared_libs {
|
|
||||||
"*.dylib"
|
|
||||||
} else {
|
|
||||||
"*.a"
|
|
||||||
}
|
|
||||||
} else if build_shared_libs {
|
|
||||||
"*.so"
|
|
||||||
} else {
|
|
||||||
"*.a"
|
|
||||||
};
|
|
||||||
let libs_dir = out_dir.join("lib*");
|
|
||||||
let pattern = libs_dir.join(lib_pattern);
|
|
||||||
debug_log!("Extract libs {}", pattern.display());
|
|
||||||
|
|
||||||
let mut lib_names: Vec<String> = Vec::new();
|
|
||||||
|
|
||||||
// Process the libraries based on the pattern
|
|
||||||
for entry in glob(pattern.to_str().unwrap()).unwrap() {
|
|
||||||
match entry {
|
|
||||||
Ok(path) => {
|
|
||||||
let stem = path.file_stem().unwrap();
|
|
||||||
let stem_str = stem.to_str().unwrap();
|
|
||||||
|
|
||||||
// Remove the "lib" prefix if present
|
|
||||||
let lib_name = if stem_str.starts_with("lib") {
|
|
||||||
stem_str.strip_prefix("lib").unwrap_or(stem_str)
|
|
||||||
} else {
|
|
||||||
if path.extension() == Some(std::ffi::OsStr::new("a")) {
|
|
||||||
let target = path.parent().unwrap().join(format!("lib{}.a", stem_str));
|
|
||||||
std::fs::rename(&path, &target).unwrap_or_else(|e| {
|
|
||||||
panic!("Failed to rename {path:?} to {target:?}: {e:?}");
|
|
||||||
})
|
|
||||||
}
|
|
||||||
stem_str
|
|
||||||
};
|
|
||||||
lib_names.push(lib_name.to_string());
|
|
||||||
}
|
|
||||||
Err(e) => println!("cargo:warning=error={}", e),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
lib_names
|
|
||||||
}
|
|
||||||
|
|
||||||
fn extract_lib_assets(out_dir: &Path) -> Vec<PathBuf> {
|
|
||||||
// Use CARGO_CFG_TARGET_OS to detect TARGET platform, not HOST
|
|
||||||
// This fixes cross-compilation from Linux to Windows
|
|
||||||
let target_os = std::env::var("CARGO_CFG_TARGET_OS").unwrap_or_default();
|
|
||||||
let shared_lib_pattern = if target_os == "windows" {
|
|
||||||
"*.dll"
|
|
||||||
} else if target_os == "macos" {
|
|
||||||
"*.dylib"
|
|
||||||
} else {
|
|
||||||
"*.so"
|
|
||||||
};
|
|
||||||
|
|
||||||
let shared_libs_dir = if target_os == "windows" { "bin" } else { "lib" };
|
|
||||||
let libs_dir = out_dir.join(shared_libs_dir);
|
|
||||||
let pattern = libs_dir.join(shared_lib_pattern);
|
|
||||||
debug_log!("Extract lib assets {}", pattern.display());
|
|
||||||
let mut files = Vec::new();
|
|
||||||
|
|
||||||
for entry in glob(pattern.to_str().unwrap()).unwrap() {
|
|
||||||
match entry {
|
|
||||||
Ok(path) => {
|
|
||||||
files.push(path);
|
|
||||||
}
|
|
||||||
Err(e) => eprintln!("cargo:warning=error={}", e),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
files
|
|
||||||
}
|
|
||||||
|
|
||||||
fn macos_link_search_path() -> Option<String> {
|
|
||||||
let output = Command::new("clang")
|
|
||||||
.arg("--print-search-dirs")
|
|
||||||
.output()
|
|
||||||
.ok()?;
|
|
||||||
if !output.status.success() {
|
|
||||||
println!(
|
|
||||||
"failed to run 'clang --print-search-dirs', continuing without a link search path"
|
|
||||||
);
|
|
||||||
return None;
|
|
||||||
}
|
|
||||||
|
|
||||||
let stdout = String::from_utf8_lossy(&output.stdout);
|
|
||||||
for line in stdout.lines() {
|
|
||||||
if line.contains("libraries: =") {
|
|
||||||
let path = line.split('=').nth(1)?;
|
|
||||||
return Some(format!("{}/lib/darwin", path));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
println!("failed to determine link search path, continuing without it");
|
|
||||||
None
|
|
||||||
}
|
|
||||||
|
|
||||||
fn validate_android_ndk(ndk_path: &str) -> Result<(), String> {
|
|
||||||
let ndk_path = Path::new(ndk_path);
|
|
||||||
|
|
||||||
if !ndk_path.exists() {
|
|
||||||
return Err(format!(
|
|
||||||
"Android NDK path does not exist: {}",
|
|
||||||
ndk_path.display()
|
|
||||||
));
|
|
||||||
}
|
|
||||||
|
|
||||||
let toolchain_file = ndk_path.join("build/cmake/android.toolchain.cmake");
|
|
||||||
if !toolchain_file.exists() {
|
|
||||||
return Err(format!(
|
|
||||||
"Android NDK toolchain file not found: {}\n\
|
|
||||||
This indicates an incomplete NDK installation.",
|
|
||||||
toolchain_file.display()
|
|
||||||
));
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn is_hidden(e: &DirEntry) -> bool {
|
|
||||||
e.file_name()
|
|
||||||
.to_str()
|
|
||||||
.map(|s| s.starts_with('.'))
|
|
||||||
.unwrap_or_default()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn main() {
|
|
||||||
println!("cargo:rerun-if-changed=build.rs");
|
|
||||||
|
|
||||||
let (target_os, target_triple) =
|
|
||||||
parse_target_os().unwrap_or_else(|t| panic!("Failed to parse target os {t}"));
|
|
||||||
let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap());
|
|
||||||
|
|
||||||
let target_dir = get_cargo_target_dir().unwrap();
|
|
||||||
let manifest_dir = env::var("CARGO_MANIFEST_DIR").expect("Failed to get CARGO_MANIFEST_DIR");
|
|
||||||
let llama_src = Path::new(&manifest_dir).join("llama.cpp");
|
|
||||||
let build_shared_libs = cfg!(feature = "dynamic-link");
|
|
||||||
|
|
||||||
let build_shared_libs = std::env::var("LLAMA_BUILD_SHARED_LIBS")
|
|
||||||
.map(|v| v == "1")
|
|
||||||
.unwrap_or(build_shared_libs);
|
|
||||||
let profile = env::var("LLAMA_LIB_PROFILE").unwrap_or("Release".to_string());
|
|
||||||
let static_crt = env::var("LLAMA_STATIC_CRT")
|
|
||||||
.map(|v| v == "1")
|
|
||||||
.unwrap_or(false);
|
|
||||||
|
|
||||||
println!("cargo:rerun-if-env-changed=LLAMA_LIB_PROFILE");
|
|
||||||
println!("cargo:rerun-if-env-changed=LLAMA_BUILD_SHARED_LIBS");
|
|
||||||
println!("cargo:rerun-if-env-changed=LLAMA_STATIC_CRT");
|
|
||||||
|
|
||||||
debug_log!("TARGET: {}", target_triple);
|
|
||||||
debug_log!("CARGO_MANIFEST_DIR: {}", manifest_dir);
|
|
||||||
debug_log!("TARGET_DIR: {}", target_dir.display());
|
|
||||||
debug_log!("OUT_DIR: {}", out_dir.display());
|
|
||||||
debug_log!("BUILD_SHARED: {}", build_shared_libs);
|
|
||||||
|
|
||||||
// Make sure that changes to the llama.cpp project trigger a rebuild.
|
|
||||||
let rebuild_on_children_of = [
|
|
||||||
llama_src.join("src"),
|
|
||||||
llama_src.join("ggml/src"),
|
|
||||||
llama_src.join("common"),
|
|
||||||
];
|
|
||||||
for entry in walkdir::WalkDir::new(&llama_src)
|
|
||||||
.into_iter()
|
|
||||||
.filter_entry(|e| !is_hidden(e))
|
|
||||||
{
|
|
||||||
let entry = entry.expect("Failed to obtain entry");
|
|
||||||
let rebuild = entry
|
|
||||||
.file_name()
|
|
||||||
.to_str()
|
|
||||||
.map(|f| f.starts_with("CMake"))
|
|
||||||
.unwrap_or_default()
|
|
||||||
|| rebuild_on_children_of
|
|
||||||
.iter()
|
|
||||||
.any(|src_folder| entry.path().starts_with(src_folder));
|
|
||||||
if rebuild {
|
|
||||||
println!("cargo:rerun-if-changed={}", entry.path().display());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Speed up build
|
|
||||||
env::set_var(
|
|
||||||
"CMAKE_BUILD_PARALLEL_LEVEL",
|
|
||||||
std::thread::available_parallelism()
|
|
||||||
.unwrap()
|
|
||||||
.get()
|
|
||||||
.to_string(),
|
|
||||||
);
|
|
||||||
|
|
||||||
// Bindings
|
|
||||||
let mut bindings_builder = bindgen::Builder::default()
|
|
||||||
.header("wrapper.h")
|
|
||||||
.clang_arg(format!("-I{}", llama_src.join("include").display()))
|
|
||||||
.clang_arg(format!("-I{}", llama_src.join("ggml/include").display()))
|
|
||||||
.parse_callbacks(Box::new(bindgen::CargoCallbacks::new()))
|
|
||||||
.derive_partialeq(true)
|
|
||||||
.allowlist_function("ggml_.*")
|
|
||||||
.allowlist_type("ggml_.*")
|
|
||||||
.allowlist_function("llama_.*")
|
|
||||||
.allowlist_type("llama_.*")
|
|
||||||
.prepend_enum_name(false);
|
|
||||||
|
|
||||||
// Configure mtmd feature if enabled
|
|
||||||
if cfg!(feature = "mtmd") {
|
|
||||||
bindings_builder = bindings_builder
|
|
||||||
.header("wrapper_mtmd.h")
|
|
||||||
.allowlist_function("mtmd_.*")
|
|
||||||
.allowlist_type("mtmd_.*");
|
|
||||||
}
|
|
||||||
|
|
||||||
// Configure Android-specific bindgen settings
|
|
||||||
if matches!(target_os, TargetOs::Android) {
|
|
||||||
// Detect Android NDK from environment variables
|
|
||||||
let android_ndk = env::var("ANDROID_NDK")
|
|
||||||
.or_else(|_| env::var("ANDROID_NDK_ROOT"))
|
|
||||||
.or_else(|_| env::var("NDK_ROOT"))
|
|
||||||
.or_else(|_| env::var("CARGO_NDK_ANDROID_NDK"))
|
|
||||||
.or_else(|_| {
|
|
||||||
// Try to auto-detect NDK from Android SDK
|
|
||||||
if let Some(home) = env::home_dir() {
|
|
||||||
let android_home = env::var("ANDROID_HOME")
|
|
||||||
.or_else(|_| env::var("ANDROID_SDK_ROOT"))
|
|
||||||
.unwrap_or_else(|_| format!("{}/Android/Sdk", home.display()));
|
|
||||||
|
|
||||||
let ndk_dir = format!("{}/ndk", android_home);
|
|
||||||
if let Ok(entries) = std::fs::read_dir(&ndk_dir) {
|
|
||||||
let mut versions: Vec<_> = entries
|
|
||||||
.filter_map(|e| e.ok())
|
|
||||||
.filter(|e| e.file_type().map(|t| t.is_dir()).unwrap_or(false))
|
|
||||||
.filter_map(|e| e.file_name().to_str().map(|s| s.to_string()))
|
|
||||||
.collect();
|
|
||||||
versions.sort();
|
|
||||||
if let Some(latest) = versions.last() {
|
|
||||||
return Ok(format!("{}/{}", ndk_dir, latest));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Err(env::VarError::NotPresent)
|
|
||||||
})
|
|
||||||
.unwrap_or_else(|_| {
|
|
||||||
panic!(
|
|
||||||
"Android NDK not found. Please set one of: ANDROID_NDK, NDK_ROOT, ANDROID_NDK_ROOT\n\
|
|
||||||
Current target: {}\n\
|
|
||||||
Download from: https://developer.android.com/ndk/downloads",
|
|
||||||
target_triple
|
|
||||||
);
|
|
||||||
});
|
|
||||||
|
|
||||||
// Get Android API level
|
|
||||||
let android_api = env::var("ANDROID_API_LEVEL")
|
|
||||||
.or_else(|_| env::var("ANDROID_PLATFORM").map(|p| p.replace("android-", "")))
|
|
||||||
.or_else(|_| env::var("CARGO_NDK_ANDROID_PLATFORM").map(|p| p.replace("android-", "")))
|
|
||||||
.unwrap_or_else(|_| "28".to_string());
|
|
||||||
|
|
||||||
// Determine host platform
|
|
||||||
let host_tag = if cfg!(target_os = "macos") {
|
|
||||||
"darwin-x86_64"
|
|
||||||
} else if cfg!(target_os = "linux") {
|
|
||||||
"linux-x86_64"
|
|
||||||
} else if cfg!(target_os = "windows") {
|
|
||||||
"windows-x86_64"
|
|
||||||
} else {
|
|
||||||
panic!("Unsupported host platform for Android NDK");
|
|
||||||
};
|
|
||||||
|
|
||||||
// Map Rust target to Android architecture
|
|
||||||
let android_target_prefix = if target_triple.contains("aarch64") {
|
|
||||||
"aarch64-linux-android"
|
|
||||||
} else if target_triple.contains("armv7") {
|
|
||||||
"arm-linux-androideabi"
|
|
||||||
} else if target_triple.contains("x86_64") {
|
|
||||||
"x86_64-linux-android"
|
|
||||||
} else if target_triple.contains("i686") {
|
|
||||||
"i686-linux-android"
|
|
||||||
} else {
|
|
||||||
panic!("Unsupported Android target: {}", target_triple);
|
|
||||||
};
|
|
||||||
|
|
||||||
// Setup Android toolchain paths
|
|
||||||
let toolchain_path = format!("{}/toolchains/llvm/prebuilt/{}", android_ndk, host_tag);
|
|
||||||
let sysroot = format!("{}/sysroot", toolchain_path);
|
|
||||||
|
|
||||||
// Validate toolchain existence
|
|
||||||
if !std::path::Path::new(&toolchain_path).exists() {
|
|
||||||
panic!(
|
|
||||||
"Android NDK toolchain not found at: {}\n\
|
|
||||||
Please ensure you have the correct Android NDK for your platform.",
|
|
||||||
toolchain_path
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Find clang builtin includes
|
|
||||||
let clang_builtin_includes = {
|
|
||||||
let clang_lib_path = format!("{}/lib/clang", toolchain_path);
|
|
||||||
std::fs::read_dir(&clang_lib_path).ok().and_then(|entries| {
|
|
||||||
entries
|
|
||||||
.filter_map(|e| e.ok())
|
|
||||||
.find(|entry| {
|
|
||||||
entry.file_type().map(|t| t.is_dir()).unwrap_or(false)
|
|
||||||
&& entry
|
|
||||||
.file_name()
|
|
||||||
.to_str()
|
|
||||||
.map(|name| name.chars().next().unwrap_or('0').is_ascii_digit())
|
|
||||||
.unwrap_or(false)
|
|
||||||
})
|
|
||||||
.and_then(|entry| {
|
|
||||||
let include_path =
|
|
||||||
format!("{}/{}/include", clang_lib_path, entry.file_name().to_str()?);
|
|
||||||
if std::path::Path::new(&include_path).exists() {
|
|
||||||
Some(include_path)
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
}
|
|
||||||
})
|
|
||||||
})
|
|
||||||
};
|
|
||||||
|
|
||||||
// Configure bindgen for Android
|
|
||||||
bindings_builder = bindings_builder
|
|
||||||
.clang_arg(format!("--sysroot={}", sysroot))
|
|
||||||
.clang_arg(format!("-D__ANDROID_API__={}", android_api))
|
|
||||||
.clang_arg("-D__ANDROID__");
|
|
||||||
|
|
||||||
// Add include paths in correct order
|
|
||||||
if let Some(ref builtin_includes) = clang_builtin_includes {
|
|
||||||
bindings_builder = bindings_builder
|
|
||||||
.clang_arg("-isystem")
|
|
||||||
.clang_arg(builtin_includes);
|
|
||||||
}
|
|
||||||
|
|
||||||
bindings_builder = bindings_builder
|
|
||||||
.clang_arg("-isystem")
|
|
||||||
.clang_arg(format!("{}/usr/include/{}", sysroot, android_target_prefix))
|
|
||||||
.clang_arg("-isystem")
|
|
||||||
.clang_arg(format!("{}/usr/include", sysroot))
|
|
||||||
.clang_arg("-include")
|
|
||||||
.clang_arg("stdbool.h")
|
|
||||||
.clang_arg("-include")
|
|
||||||
.clang_arg("stdint.h");
|
|
||||||
|
|
||||||
// Set additional clang args for cargo ndk compatibility
|
|
||||||
if env::var("CARGO_SUBCOMMAND").as_deref() == Ok("ndk") {
|
|
||||||
std::env::set_var(
|
|
||||||
"BINDGEN_EXTRA_CLANG_ARGS",
|
|
||||||
format!("--target={}", target_triple),
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Fix bindgen header discovery on Windows MSVC
|
|
||||||
// Use cc crate to discover MSVC include paths by compiling a dummy file
|
|
||||||
if matches!(target_os, TargetOs::Windows(WindowsVariant::Msvc)) {
|
|
||||||
// Create a minimal dummy C file to extract compiler flags
|
|
||||||
let out_dir = env::var("OUT_DIR").unwrap();
|
|
||||||
let dummy_c = Path::new(&out_dir).join("dummy.c");
|
|
||||||
std::fs::write(&dummy_c, "int main() { return 0; }").unwrap();
|
|
||||||
|
|
||||||
// Use cc crate to get compiler with proper environment setup
|
|
||||||
let mut build = cc::Build::new();
|
|
||||||
build.file(&dummy_c);
|
|
||||||
|
|
||||||
// Get the actual compiler command cc would use
|
|
||||||
let compiler = build.try_get_compiler().unwrap();
|
|
||||||
|
|
||||||
// Extract include paths by checking compiler's environment
|
|
||||||
// cc crate sets up MSVC environment internally
|
|
||||||
let env_include = compiler
|
|
||||||
.env()
|
|
||||||
.iter()
|
|
||||||
.find(|(k, _)| k.eq_ignore_ascii_case("INCLUDE"))
|
|
||||||
.map(|(_, v)| v);
|
|
||||||
|
|
||||||
if let Some(include_paths) = env_include {
|
|
||||||
for include_path in include_paths
|
|
||||||
.to_string_lossy()
|
|
||||||
.split(';')
|
|
||||||
.filter(|s| !s.is_empty())
|
|
||||||
{
|
|
||||||
bindings_builder = bindings_builder
|
|
||||||
.clang_arg("-isystem")
|
|
||||||
.clang_arg(include_path);
|
|
||||||
debug_log!("Added MSVC include path: {}", include_path);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add MSVC compatibility flags
|
|
||||||
bindings_builder = bindings_builder
|
|
||||||
.clang_arg(format!("--target={}", target_triple))
|
|
||||||
.clang_arg("-fms-compatibility")
|
|
||||||
.clang_arg("-fms-extensions");
|
|
||||||
|
|
||||||
debug_log!(
|
|
||||||
"Configured bindgen with MSVC toolchain for target: {}",
|
|
||||||
target_triple
|
|
||||||
);
|
|
||||||
}
|
|
||||||
let bindings = bindings_builder
|
|
||||||
.generate()
|
|
||||||
.expect("Failed to generate bindings");
|
|
||||||
|
|
||||||
// Write the generated bindings to an output file
|
|
||||||
let bindings_path = out_dir.join("bindings.rs");
|
|
||||||
bindings
|
|
||||||
.write_to_file(bindings_path)
|
|
||||||
.expect("Failed to write bindings");
|
|
||||||
|
|
||||||
println!("cargo:rerun-if-changed=wrapper.h");
|
|
||||||
println!("cargo:rerun-if-changed=wrapper_mtmd.h");
|
|
||||||
|
|
||||||
debug_log!("Bindings Created");
|
|
||||||
|
|
||||||
// Build with Cmake
|
|
||||||
|
|
||||||
let mut config = Config::new(&llama_src);
|
|
||||||
|
|
||||||
// Would require extra source files to pointlessly
|
|
||||||
// be included in what's uploaded to and downloaded from
|
|
||||||
// crates.io, so deactivating these instead
|
|
||||||
config.define("LLAMA_BUILD_TESTS", "OFF");
|
|
||||||
config.define("LLAMA_BUILD_EXAMPLES", "OFF");
|
|
||||||
config.define("LLAMA_BUILD_SERVER", "OFF");
|
|
||||||
config.define("LLAMA_BUILD_TOOLS", "OFF");
|
|
||||||
config.define("LLAMA_CURL", "OFF");
|
|
||||||
|
|
||||||
if cfg!(feature = "mtmd") {
|
|
||||||
config.define("LLAMA_BUILD_COMMON", "ON");
|
|
||||||
// mtmd support in llama-cpp is within the tools directory
|
|
||||||
config.define("LLAMA_BUILD_TOOLS", "ON");
|
|
||||||
}
|
|
||||||
|
|
||||||
// Pass CMAKE_ environment variables down to CMake
|
|
||||||
for (key, value) in env::vars() {
|
|
||||||
if key.starts_with("CMAKE_") {
|
|
||||||
config.define(&key, &value);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// extract the target-cpu config value, if specified
|
|
||||||
let target_cpu = std::env::var("CARGO_ENCODED_RUSTFLAGS")
|
|
||||||
.ok()
|
|
||||||
.and_then(|rustflags| {
|
|
||||||
rustflags
|
|
||||||
.split('\x1f')
|
|
||||||
.find(|f| f.contains("target-cpu="))
|
|
||||||
.and_then(|f| f.split("target-cpu=").nth(1))
|
|
||||||
.map(|s| s.to_string())
|
|
||||||
});
|
|
||||||
|
|
||||||
if target_cpu == Some("native".into()) {
|
|
||||||
debug_log!("Detected target-cpu=native, compiling with GGML_NATIVE");
|
|
||||||
config.define("GGML_NATIVE", "ON");
|
|
||||||
}
|
|
||||||
// if native isn't specified, enable specific features for ggml instead
|
|
||||||
else {
|
|
||||||
// rust code isn't using `target-cpu=native`, so llama.cpp shouldn't use GGML_NATIVE either
|
|
||||||
config.define("GGML_NATIVE", "OFF");
|
|
||||||
|
|
||||||
// if `target-cpu` is set set, also set -march for llama.cpp to the same value
|
|
||||||
if let Some(ref cpu) = target_cpu {
|
|
||||||
debug_log!("Setting baseline architecture: -march={}", cpu);
|
|
||||||
config.cflag(&format!("-march={}", cpu));
|
|
||||||
config.cxxflag(&format!("-march={}", cpu));
|
|
||||||
}
|
|
||||||
|
|
||||||
// I expect this env var to always be present
|
|
||||||
let features = std::env::var("CARGO_CFG_TARGET_FEATURE")
|
|
||||||
.expect("Env var CARGO_CFG_TARGET_FEATURE not found.");
|
|
||||||
debug_log!("Compiling with target features: {}", features);
|
|
||||||
|
|
||||||
// list of rust target_features here:
|
|
||||||
// https://doc.rust-lang.org/reference/attributes/codegen.html#the-target_feature-attribute
|
|
||||||
// GGML config flags have been found by looking at:
|
|
||||||
// llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt
|
|
||||||
for feature in features.split(',') {
|
|
||||||
match feature {
|
|
||||||
"avx" => {
|
|
||||||
config.define("GGML_AVX", "ON");
|
|
||||||
}
|
|
||||||
"avx2" => {
|
|
||||||
config.define("GGML_AVX2", "ON");
|
|
||||||
}
|
|
||||||
"avx512bf16" => {
|
|
||||||
config.define("GGML_AVX512_BF16", "ON");
|
|
||||||
}
|
|
||||||
"avx512vbmi" => {
|
|
||||||
config.define("GGML_AVX512_VBMI", "ON");
|
|
||||||
}
|
|
||||||
"avx512vnni" => {
|
|
||||||
config.define("GGML_AVX512_VNNI", "ON");
|
|
||||||
}
|
|
||||||
"avxvnni" => {
|
|
||||||
config.define("GGML_AVX_VNNI", "ON");
|
|
||||||
}
|
|
||||||
"bmi2" => {
|
|
||||||
config.define("GGML_BMI2", "ON");
|
|
||||||
}
|
|
||||||
"f16c" => {
|
|
||||||
config.define("GGML_F16C", "ON");
|
|
||||||
}
|
|
||||||
"fma" => {
|
|
||||||
config.define("GGML_FMA", "ON");
|
|
||||||
}
|
|
||||||
"sse4.2" => {
|
|
||||||
config.define("GGML_SSE42", "ON");
|
|
||||||
}
|
|
||||||
_ => {
|
|
||||||
debug_log!(
|
|
||||||
"Unrecognized cpu feature: '{}' - skipping GGML config for it.",
|
|
||||||
feature
|
|
||||||
);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
config.define(
|
|
||||||
"BUILD_SHARED_LIBS",
|
|
||||||
if build_shared_libs { "ON" } else { "OFF" },
|
|
||||||
);
|
|
||||||
|
|
||||||
if matches!(target_os, TargetOs::Apple(_)) {
|
|
||||||
config.define("GGML_BLAS", "OFF");
|
|
||||||
}
|
|
||||||
|
|
||||||
if (matches!(target_os, TargetOs::Windows(WindowsVariant::Msvc))
|
|
||||||
&& matches!(
|
|
||||||
profile.as_str(),
|
|
||||||
"Release" | "RelWithDebInfo" | "MinSizeRel"
|
|
||||||
))
|
|
||||||
{
|
|
||||||
// Debug Rust builds under MSVC turn off optimization even though we're ideally building the release profile of llama.cpp.
|
|
||||||
// Looks like an upstream bug:
|
|
||||||
// https://github.com/rust-lang/cmake-rs/issues/240
|
|
||||||
// For now explicitly reinject the optimization flags that a CMake Release build is expected to have on in this scenario.
|
|
||||||
// This fixes CPU inference performance when part of a Rust debug build.
|
|
||||||
for flag in &["/O2", "/DNDEBUG", "/Ob2"] {
|
|
||||||
config.cflag(flag);
|
|
||||||
config.cxxflag(flag);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
config.static_crt(static_crt);
|
|
||||||
|
|
||||||
if matches!(target_os, TargetOs::Android) {
|
|
||||||
// Android NDK Build Configuration
|
|
||||||
let android_ndk = env::var("ANDROID_NDK")
|
|
||||||
.or_else(|_| env::var("NDK_ROOT"))
|
|
||||||
.or_else(|_| env::var("ANDROID_NDK_ROOT"))
|
|
||||||
.unwrap_or_else(|_| {
|
|
||||||
panic!(
|
|
||||||
"Android NDK not found. Please set one of: ANDROID_NDK, NDK_ROOT, ANDROID_NDK_ROOT\n\
|
|
||||||
Download from: https://developer.android.com/ndk/downloads"
|
|
||||||
);
|
|
||||||
});
|
|
||||||
|
|
||||||
// Validate NDK installation
|
|
||||||
if let Err(error) = validate_android_ndk(&android_ndk) {
|
|
||||||
panic!("Android NDK validation failed: {}", error);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Rerun build script if NDK environment variables change
|
|
||||||
println!("cargo:rerun-if-env-changed=ANDROID_NDK");
|
|
||||||
println!("cargo:rerun-if-env-changed=NDK_ROOT");
|
|
||||||
println!("cargo:rerun-if-env-changed=ANDROID_NDK_ROOT");
|
|
||||||
|
|
||||||
// Set CMake toolchain file for Android
|
|
||||||
let toolchain_file = format!("{}/build/cmake/android.toolchain.cmake", android_ndk);
|
|
||||||
config.define("CMAKE_TOOLCHAIN_FILE", &toolchain_file);
|
|
||||||
|
|
||||||
// Configure Android platform (API level)
|
|
||||||
let android_platform = env::var("ANDROID_PLATFORM").unwrap_or_else(|_| {
|
|
||||||
env::var("ANDROID_API_LEVEL")
|
|
||||||
.map(|level| format!("android-{}", level))
|
|
||||||
.unwrap_or_else(|_| "android-28".to_string())
|
|
||||||
});
|
|
||||||
|
|
||||||
println!("cargo:rerun-if-env-changed=ANDROID_PLATFORM");
|
|
||||||
println!("cargo:rerun-if-env-changed=ANDROID_API_LEVEL");
|
|
||||||
config.define("ANDROID_PLATFORM", &android_platform);
|
|
||||||
|
|
||||||
// Map Rust target to Android ABI
|
|
||||||
let android_abi = if target_triple.contains("aarch64") {
|
|
||||||
"arm64-v8a"
|
|
||||||
} else if target_triple.contains("armv7") {
|
|
||||||
"armeabi-v7a"
|
|
||||||
} else if target_triple.contains("x86_64") {
|
|
||||||
"x86_64"
|
|
||||||
} else if target_triple.contains("i686") {
|
|
||||||
"x86"
|
|
||||||
} else {
|
|
||||||
panic!(
|
|
||||||
"Unsupported Android target: {}\n\
|
|
||||||
Supported targets: aarch64-linux-android, armv7-linux-androideabi, i686-linux-android, x86_64-linux-android",
|
|
||||||
target_triple
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
config.define("ANDROID_ABI", android_abi);
|
|
||||||
|
|
||||||
// Configure architecture-specific compiler flags
|
|
||||||
match android_abi {
|
|
||||||
"arm64-v8a" => {
|
|
||||||
config.cflag("-march=armv8-a");
|
|
||||||
config.cxxflag("-march=armv8-a");
|
|
||||||
}
|
|
||||||
"armeabi-v7a" => {
|
|
||||||
config.cflag("-march=armv7-a");
|
|
||||||
config.cxxflag("-march=armv7-a");
|
|
||||||
config.cflag("-mfpu=neon");
|
|
||||||
config.cxxflag("-mfpu=neon");
|
|
||||||
config.cflag("-mthumb");
|
|
||||||
config.cxxflag("-mthumb");
|
|
||||||
}
|
|
||||||
"x86_64" => {
|
|
||||||
config.cflag("-march=x86-64");
|
|
||||||
config.cxxflag("-march=x86-64");
|
|
||||||
}
|
|
||||||
"x86" => {
|
|
||||||
config.cflag("-march=i686");
|
|
||||||
config.cxxflag("-march=i686");
|
|
||||||
}
|
|
||||||
_ => {}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Android-specific CMake configurations
|
|
||||||
config.define("GGML_LLAMAFILE", "OFF");
|
|
||||||
|
|
||||||
// Link Android system libraries
|
|
||||||
println!("cargo:rustc-link-lib=log");
|
|
||||||
println!("cargo:rustc-link-lib=android");
|
|
||||||
}
|
|
||||||
|
|
||||||
if matches!(target_os, TargetOs::Linux)
|
|
||||||
&& target_triple.contains("aarch64")
|
|
||||||
&& target_cpu != Some("native".into())
|
|
||||||
{
|
|
||||||
// If the target-cpu is not specified as native, we take off the native ARM64 support.
|
|
||||||
// It is useful in docker environments where the native feature is not enabled.
|
|
||||||
config.define("GGML_NATIVE", "OFF");
|
|
||||||
config.define("GGML_CPU_ARM_ARCH", "armv8-a");
|
|
||||||
}
|
|
||||||
|
|
||||||
if cfg!(feature = "vulkan") {
|
|
||||||
config.define("GGML_VULKAN", "ON");
|
|
||||||
match target_os {
|
|
||||||
TargetOs::Windows(_) => {
|
|
||||||
let vulkan_path = env::var("VULKAN_SDK").expect(
|
|
||||||
"Please install Vulkan SDK and ensure that VULKAN_SDK env variable is set",
|
|
||||||
);
|
|
||||||
let vulkan_lib_path = Path::new(&vulkan_path).join("Lib");
|
|
||||||
println!("cargo:rustc-link-search={}", vulkan_lib_path.display());
|
|
||||||
println!("cargo:rustc-link-lib=vulkan-1");
|
|
||||||
|
|
||||||
// workaround for this error: "FileTracker : error FTK1011: could not create the new file tracking log file"
|
|
||||||
// it has to do with MSBuild FileTracker not respecting the path
|
|
||||||
// limit configuration set in the windows registry.
|
|
||||||
// I'm not sure why that's a thing, but this makes my builds work.
|
|
||||||
// (crates that depend on llama-cpp-rs w/ vulkan easily exceed the default PATH_MAX on windows)
|
|
||||||
env::set_var("TrackFileAccess", "false");
|
|
||||||
// since we disabled TrackFileAccess, we can now run into problems with parallel
|
|
||||||
// access to pdb files. /FS solves this.
|
|
||||||
config.cflag("/FS");
|
|
||||||
config.cxxflag("/FS");
|
|
||||||
}
|
|
||||||
TargetOs::Linux => {
|
|
||||||
// If we are not using system provided vulkan SDK, add vulkan libs for linking
|
|
||||||
if let Ok(vulkan_path) = env::var("VULKAN_SDK") {
|
|
||||||
let vulkan_lib_path = Path::new(&vulkan_path).join("lib");
|
|
||||||
println!("cargo:rustc-link-search={}", vulkan_lib_path.display());
|
|
||||||
}
|
|
||||||
println!("cargo:rustc-link-lib=vulkan");
|
|
||||||
}
|
|
||||||
_ => (),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if cfg!(feature = "cuda") {
|
|
||||||
config.define("GGML_CUDA", "ON");
|
|
||||||
|
|
||||||
if cfg!(feature = "cuda-no-vmm") {
|
|
||||||
config.define("GGML_CUDA_NO_VMM", "ON");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Android doesn't have OpenMP support AFAICT and openmp is a default feature. Do this here
|
|
||||||
// rather than modifying the defaults in Cargo.toml just in case someone enables the OpenMP feature
|
|
||||||
// and tries to build for Android anyway.
|
|
||||||
if cfg!(feature = "openmp") && !matches!(target_os, TargetOs::Android) {
|
|
||||||
config.define("GGML_OPENMP", "ON");
|
|
||||||
} else {
|
|
||||||
config.define("GGML_OPENMP", "OFF");
|
|
||||||
}
|
|
||||||
|
|
||||||
if cfg!(feature = "system-ggml") {
|
|
||||||
config.define("LLAMA_USE_SYSTEM_GGML", "ON");
|
|
||||||
}
|
|
||||||
|
|
||||||
// General
|
|
||||||
config
|
|
||||||
.profile(&profile)
|
|
||||||
.very_verbose(std::env::var("CMAKE_VERBOSE").is_ok()) // Not verbose by default
|
|
||||||
.always_configure(false);
|
|
||||||
|
|
||||||
let build_dir = config.build();
|
|
||||||
|
|
||||||
// Search paths
|
|
||||||
println!("cargo:rustc-link-search={}", out_dir.join("lib").display());
|
|
||||||
println!(
|
|
||||||
"cargo:rustc-link-search={}",
|
|
||||||
out_dir.join("lib64").display()
|
|
||||||
);
|
|
||||||
println!("cargo:rustc-link-search={}", build_dir.display());
|
|
||||||
|
|
||||||
if cfg!(feature = "system-ggml") {
|
|
||||||
// Extract library directory from CMake's found GGML package
|
|
||||||
let cmake_cache = build_dir.join("build").join("CMakeCache.txt");
|
|
||||||
if let Ok(cache_contents) = std::fs::read_to_string(&cmake_cache) {
|
|
||||||
let mut ggml_lib_dirs = std::collections::HashSet::new();
|
|
||||||
|
|
||||||
// Parse CMakeCache.txt to find where GGML libraries were found
|
|
||||||
for line in cache_contents.lines() {
|
|
||||||
if line.starts_with("GGML_LIBRARY:")
|
|
||||||
|| line.starts_with("GGML_BASE_LIBRARY:")
|
|
||||||
|| line.starts_with("GGML_CPU_LIBRARY:")
|
|
||||||
{
|
|
||||||
if let Some(lib_path) = line.split('=').nth(1) {
|
|
||||||
if let Some(parent) = Path::new(lib_path).parent() {
|
|
||||||
ggml_lib_dirs.insert(parent.to_path_buf());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add each unique library directory to the search path
|
|
||||||
for lib_dir in ggml_lib_dirs {
|
|
||||||
println!("cargo:rustc-link-search=native={}", lib_dir.display());
|
|
||||||
debug_log!("Added system GGML library path: {}", lib_dir.display());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if cfg!(feature = "cuda") && !build_shared_libs {
|
|
||||||
// Re-run build script if CUDA_PATH environment variable changes
|
|
||||||
println!("cargo:rerun-if-env-changed=CUDA_PATH");
|
|
||||||
|
|
||||||
// Add CUDA library directories to the linker search path
|
|
||||||
for lib_dir in find_cuda_helper::find_cuda_lib_dirs() {
|
|
||||||
println!("cargo:rustc-link-search=native={}", lib_dir.display());
|
|
||||||
}
|
|
||||||
|
|
||||||
// Platform-specific linking
|
|
||||||
if cfg!(target_os = "windows") {
|
|
||||||
// ✅ On Windows, use dynamic linking.
|
|
||||||
// Static linking is problematic because NVIDIA does not provide culibos.lib,
|
|
||||||
// and static CUDA libraries (like cublas_static.lib) are usually not shipped.
|
|
||||||
|
|
||||||
println!("cargo:rustc-link-lib=cudart"); // Links to cudart64_*.dll
|
|
||||||
println!("cargo:rustc-link-lib=cublas"); // Links to cublas64_*.dll
|
|
||||||
println!("cargo:rustc-link-lib=cublasLt"); // Links to cublasLt64_*.dll
|
|
||||||
|
|
||||||
// Link to CUDA driver API (nvcuda.dll via cuda.lib)
|
|
||||||
if !cfg!(feature = "cuda-no-vmm") {
|
|
||||||
println!("cargo:rustc-link-lib=cuda");
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// ✅ On non-Windows platforms (e.g., Linux), static linking is preferred and supported.
|
|
||||||
// Static libraries like cudart_static and cublas_static depend on culibos.
|
|
||||||
|
|
||||||
println!("cargo:rustc-link-lib=static=cudart_static");
|
|
||||||
println!("cargo:rustc-link-lib=static=cublas_static");
|
|
||||||
println!("cargo:rustc-link-lib=static=cublasLt_static");
|
|
||||||
|
|
||||||
// Link to CUDA driver API (libcuda.so)
|
|
||||||
if !cfg!(feature = "cuda-no-vmm") {
|
|
||||||
println!("cargo:rustc-link-lib=cuda");
|
|
||||||
}
|
|
||||||
|
|
||||||
// culibos is required when statically linking cudart_static
|
|
||||||
println!("cargo:rustc-link-lib=static=culibos");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Link libraries
|
|
||||||
let llama_libs_kind = if build_shared_libs || cfg!(feature = "system-ggml") {
|
|
||||||
"dylib"
|
|
||||||
} else {
|
|
||||||
"static"
|
|
||||||
};
|
|
||||||
let llama_libs = extract_lib_names(&out_dir, build_shared_libs);
|
|
||||||
assert_ne!(llama_libs.len(), 0);
|
|
||||||
|
|
||||||
if cfg!(feature = "system-ggml") {
|
|
||||||
println!("cargo:rustc-link-lib={llama_libs_kind}=ggml");
|
|
||||||
println!("cargo:rustc-link-lib={llama_libs_kind}=ggml-base");
|
|
||||||
println!("cargo:rustc-link-lib={llama_libs_kind}=ggml-cpu");
|
|
||||||
}
|
|
||||||
for lib in llama_libs {
|
|
||||||
let link = format!("cargo:rustc-link-lib={}={}", llama_libs_kind, lib);
|
|
||||||
debug_log!("LINK {link}",);
|
|
||||||
println!("{link}",);
|
|
||||||
}
|
|
||||||
|
|
||||||
// OpenMP
|
|
||||||
if cfg!(feature = "openmp") && target_triple.contains("gnu") {
|
|
||||||
println!("cargo:rustc-link-lib=gomp");
|
|
||||||
}
|
|
||||||
|
|
||||||
match target_os {
|
|
||||||
TargetOs::Windows(WindowsVariant::Msvc) => {
|
|
||||||
println!("cargo:rustc-link-lib=advapi32");
|
|
||||||
if cfg!(debug_assertions) {
|
|
||||||
println!("cargo:rustc-link-lib=dylib=msvcrtd");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
TargetOs::Linux => {
|
|
||||||
println!("cargo:rustc-link-lib=dylib=stdc++");
|
|
||||||
}
|
|
||||||
TargetOs::Apple(variant) => {
|
|
||||||
println!("cargo:rustc-link-lib=framework=Foundation");
|
|
||||||
println!("cargo:rustc-link-lib=framework=Metal");
|
|
||||||
println!("cargo:rustc-link-lib=framework=MetalKit");
|
|
||||||
println!("cargo:rustc-link-lib=framework=Accelerate");
|
|
||||||
println!("cargo:rustc-link-lib=c++");
|
|
||||||
|
|
||||||
match variant {
|
|
||||||
AppleVariant::MacOS => {
|
|
||||||
// On (older) OSX we need to link against the clang runtime,
|
|
||||||
// which is hidden in some non-default path.
|
|
||||||
//
|
|
||||||
// More details at https://github.com/alexcrichton/curl-rust/issues/279.
|
|
||||||
if let Some(path) = macos_link_search_path() {
|
|
||||||
println!("cargo:rustc-link-lib=clang_rt.osx");
|
|
||||||
println!("cargo:rustc-link-search={}", path);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
AppleVariant::Other => (),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
_ => (),
|
|
||||||
}
|
|
||||||
|
|
||||||
// copy DLLs to target
|
|
||||||
if build_shared_libs {
|
|
||||||
let libs_assets = extract_lib_assets(&out_dir);
|
|
||||||
for asset in libs_assets {
|
|
||||||
let asset_clone = asset.clone();
|
|
||||||
let filename = asset_clone.file_name().unwrap();
|
|
||||||
let filename = filename.to_str().unwrap();
|
|
||||||
let dst = target_dir.join(filename);
|
|
||||||
debug_log!("HARD LINK {} TO {}", asset.display(), dst.display());
|
|
||||||
if !dst.exists() {
|
|
||||||
std::fs::hard_link(asset.clone(), dst).unwrap();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Copy DLLs to examples as well
|
|
||||||
if target_dir.join("examples").exists() {
|
|
||||||
let dst = target_dir.join("examples").join(filename);
|
|
||||||
debug_log!("HARD LINK {} TO {}", asset.display(), dst.display());
|
|
||||||
if !dst.exists() {
|
|
||||||
std::fs::hard_link(asset.clone(), dst).unwrap();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Copy DLLs to target/profile/deps as well for tests
|
|
||||||
let dst = target_dir.join("deps").join(filename);
|
|
||||||
debug_log!("HARD LINK {} TO {}", asset.display(), dst.display());
|
|
||||||
if !dst.exists() {
|
|
||||||
std::fs::hard_link(asset.clone(), dst).unwrap();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,309 +0,0 @@
|
|||||||
cmake_minimum_required(VERSION 3.14) # for add_link_options and implicit target directories.
|
|
||||||
project("llama.cpp" C CXX)
|
|
||||||
include(CheckIncludeFileCXX)
|
|
||||||
|
|
||||||
#set(CMAKE_WARN_DEPRECATED YES)
|
|
||||||
set(CMAKE_WARN_UNUSED_CLI YES)
|
|
||||||
|
|
||||||
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
|
|
||||||
|
|
||||||
if (NOT XCODE AND NOT MSVC AND NOT CMAKE_BUILD_TYPE)
|
|
||||||
set(CMAKE_BUILD_TYPE Release CACHE STRING "Build type" FORCE)
|
|
||||||
set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release" "MinSizeRel" "RelWithDebInfo")
|
|
||||||
endif()
|
|
||||||
|
|
||||||
message("CMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}")
|
|
||||||
|
|
||||||
# Add path to modules
|
|
||||||
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/")
|
|
||||||
|
|
||||||
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
|
|
||||||
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
|
|
||||||
|
|
||||||
if (CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR)
|
|
||||||
set(LLAMA_STANDALONE ON)
|
|
||||||
|
|
||||||
include(git-vars)
|
|
||||||
|
|
||||||
# configure project version
|
|
||||||
# TODO
|
|
||||||
else()
|
|
||||||
set(LLAMA_STANDALONE OFF)
|
|
||||||
endif()
|
|
||||||
|
|
||||||
option(LLAMA_USE_SYSTEM_GGML "Use system libggml" OFF)
|
|
||||||
|
|
||||||
option(LLAMA_WASM_MEM64 "llama: use 64-bit memory in WASM builds" ON)
|
|
||||||
|
|
||||||
if (EMSCRIPTEN)
|
|
||||||
set(BUILD_SHARED_LIBS_DEFAULT OFF)
|
|
||||||
|
|
||||||
# Use 64-bit memory to support backend_get_memory queries
|
|
||||||
# TODO: analyze performance impact, see https://spidermonkey.dev/blog/2025/01/15/is-memory64-actually-worth-using
|
|
||||||
if (LLAMA_WASM_MEM64)
|
|
||||||
add_compile_options("-sMEMORY64=1")
|
|
||||||
add_link_options("-sMEMORY64=1")
|
|
||||||
endif()
|
|
||||||
add_link_options("-sALLOW_MEMORY_GROWTH=1")
|
|
||||||
|
|
||||||
option(LLAMA_WASM_SINGLE_FILE "llama: embed WASM inside the generated llama.js" OFF)
|
|
||||||
option(LLAMA_BUILD_HTML "llama: build HTML file" ON)
|
|
||||||
if (LLAMA_BUILD_HTML)
|
|
||||||
set(CMAKE_EXECUTABLE_SUFFIX ".html")
|
|
||||||
endif()
|
|
||||||
else()
|
|
||||||
if (MINGW)
|
|
||||||
set(BUILD_SHARED_LIBS_DEFAULT OFF)
|
|
||||||
else()
|
|
||||||
set(BUILD_SHARED_LIBS_DEFAULT ON)
|
|
||||||
endif()
|
|
||||||
endif()
|
|
||||||
|
|
||||||
option(BUILD_SHARED_LIBS "build shared libraries" ${BUILD_SHARED_LIBS_DEFAULT})
|
|
||||||
|
|
||||||
if (WIN32)
|
|
||||||
add_compile_definitions(_CRT_SECURE_NO_WARNINGS)
|
|
||||||
endif()
|
|
||||||
|
|
||||||
if (MSVC)
|
|
||||||
add_compile_options("$<$<COMPILE_LANGUAGE:C>:/utf-8>")
|
|
||||||
add_compile_options("$<$<COMPILE_LANGUAGE:CXX>:/utf-8>")
|
|
||||||
add_compile_options("$<$<COMPILE_LANGUAGE:C>:/bigobj>")
|
|
||||||
add_compile_options("$<$<COMPILE_LANGUAGE:CXX>:/bigobj>")
|
|
||||||
endif()
|
|
||||||
|
|
||||||
if (LLAMA_STANDALONE)
|
|
||||||
# enable parallel builds for msbuild
|
|
||||||
list(APPEND CMAKE_VS_GLOBALS UseMultiToolTask=true)
|
|
||||||
list(APPEND CMAKE_VS_GLOBALS EnforceProcessCountAcrossBuilds=true)
|
|
||||||
endif()
|
|
||||||
|
|
||||||
if (CMAKE_SYSTEM_NAME STREQUAL "iOS")
|
|
||||||
set(LLAMA_TOOLS_INSTALL_DEFAULT OFF)
|
|
||||||
else()
|
|
||||||
set(LLAMA_TOOLS_INSTALL_DEFAULT ${LLAMA_STANDALONE})
|
|
||||||
endif()
|
|
||||||
|
|
||||||
#
|
|
||||||
# option list
|
|
||||||
#
|
|
||||||
|
|
||||||
# debug
|
|
||||||
option(LLAMA_ALL_WARNINGS "llama: enable all compiler warnings" ON)
|
|
||||||
option(LLAMA_ALL_WARNINGS_3RD_PARTY "llama: enable all compiler warnings in 3rd party libs" OFF)
|
|
||||||
|
|
||||||
# build
|
|
||||||
option(LLAMA_FATAL_WARNINGS "llama: enable -Werror flag" OFF)
|
|
||||||
|
|
||||||
# sanitizers
|
|
||||||
option(LLAMA_SANITIZE_THREAD "llama: enable thread sanitizer" OFF)
|
|
||||||
option(LLAMA_SANITIZE_ADDRESS "llama: enable address sanitizer" OFF)
|
|
||||||
option(LLAMA_SANITIZE_UNDEFINED "llama: enable undefined sanitizer" OFF)
|
|
||||||
|
|
||||||
# utils
|
|
||||||
option(LLAMA_BUILD_COMMON "llama: build common utils library" ${LLAMA_STANDALONE})
|
|
||||||
|
|
||||||
# extra artifacts
|
|
||||||
option(LLAMA_BUILD_TESTS "llama: build tests" ${LLAMA_STANDALONE})
|
|
||||||
option(LLAMA_BUILD_TOOLS "llama: build tools" ${LLAMA_STANDALONE})
|
|
||||||
option(LLAMA_BUILD_EXAMPLES "llama: build examples" ${LLAMA_STANDALONE})
|
|
||||||
option(LLAMA_BUILD_SERVER "llama: build server example" ${LLAMA_STANDALONE})
|
|
||||||
option(LLAMA_TOOLS_INSTALL "llama: install tools" ${LLAMA_TOOLS_INSTALL_DEFAULT})
|
|
||||||
|
|
||||||
# 3rd party libs
|
|
||||||
option(LLAMA_CURL "llama: use libcurl to download model from an URL" ON)
|
|
||||||
option(LLAMA_HTTPLIB "llama: if libcurl is disabled, use httplib to download model from an URL" ON)
|
|
||||||
option(LLAMA_OPENSSL "llama: use openssl to support HTTPS" OFF)
|
|
||||||
option(LLAMA_LLGUIDANCE "llama-common: include LLGuidance library for structured output in common utils" OFF)
|
|
||||||
|
|
||||||
# Required for relocatable CMake package
|
|
||||||
include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/build-info.cmake)
|
|
||||||
include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/common.cmake)
|
|
||||||
|
|
||||||
if (NOT DEFINED LLAMA_BUILD_NUMBER)
|
|
||||||
set(LLAMA_BUILD_NUMBER ${BUILD_NUMBER})
|
|
||||||
endif()
|
|
||||||
if (NOT DEFINED LLAMA_BUILD_COMMIT)
|
|
||||||
set(LLAMA_BUILD_COMMIT ${BUILD_COMMIT})
|
|
||||||
endif()
|
|
||||||
set(LLAMA_INSTALL_VERSION 0.0.${LLAMA_BUILD_NUMBER})
|
|
||||||
|
|
||||||
# override ggml options
|
|
||||||
set(GGML_ALL_WARNINGS ${LLAMA_ALL_WARNINGS})
|
|
||||||
set(GGML_FATAL_WARNINGS ${LLAMA_FATAL_WARNINGS})
|
|
||||||
|
|
||||||
# change the default for these ggml options
|
|
||||||
if (NOT DEFINED GGML_LLAMAFILE)
|
|
||||||
set(GGML_LLAMAFILE_DEFAULT ON)
|
|
||||||
endif()
|
|
||||||
|
|
||||||
if (NOT DEFINED GGML_CUDA_GRAPHS)
|
|
||||||
set(GGML_CUDA_GRAPHS_DEFAULT ON)
|
|
||||||
endif()
|
|
||||||
|
|
||||||
# transition helpers
|
|
||||||
function (llama_option_depr TYPE OLD NEW)
|
|
||||||
if (${OLD})
|
|
||||||
message(${TYPE} "${OLD} is deprecated and will be removed in the future.\nUse ${NEW} instead\n")
|
|
||||||
set(${NEW} ON PARENT_SCOPE)
|
|
||||||
endif()
|
|
||||||
endfunction()
|
|
||||||
|
|
||||||
llama_option_depr(FATAL_ERROR LLAMA_CUBLAS GGML_CUDA)
|
|
||||||
llama_option_depr(WARNING LLAMA_CUDA GGML_CUDA)
|
|
||||||
llama_option_depr(WARNING LLAMA_METAL GGML_METAL)
|
|
||||||
llama_option_depr(WARNING LLAMA_METAL_EMBED_LIBRARY GGML_METAL_EMBED_LIBRARY)
|
|
||||||
llama_option_depr(WARNING LLAMA_NATIVE GGML_NATIVE)
|
|
||||||
llama_option_depr(WARNING LLAMA_RPC GGML_RPC)
|
|
||||||
llama_option_depr(WARNING LLAMA_SYCL GGML_SYCL)
|
|
||||||
llama_option_depr(WARNING LLAMA_SYCL_F16 GGML_SYCL_F16)
|
|
||||||
llama_option_depr(WARNING LLAMA_CANN GGML_CANN)
|
|
||||||
|
|
||||||
if (NOT MSVC)
|
|
||||||
if (LLAMA_SANITIZE_THREAD)
|
|
||||||
message(STATUS "Using -fsanitize=thread")
|
|
||||||
|
|
||||||
add_compile_options(-fsanitize=thread)
|
|
||||||
link_libraries (-fsanitize=thread)
|
|
||||||
endif()
|
|
||||||
|
|
||||||
if (LLAMA_SANITIZE_ADDRESS)
|
|
||||||
message(STATUS "Using -fsanitize=address")
|
|
||||||
|
|
||||||
add_compile_options(-fsanitize=address -fno-omit-frame-pointer)
|
|
||||||
link_libraries (-fsanitize=address)
|
|
||||||
endif()
|
|
||||||
|
|
||||||
if (LLAMA_SANITIZE_UNDEFINED)
|
|
||||||
message(STATUS "Using -fsanitize=undefined")
|
|
||||||
|
|
||||||
add_compile_options(-fsanitize=undefined)
|
|
||||||
link_libraries (-fsanitize=undefined)
|
|
||||||
endif()
|
|
||||||
endif()
|
|
||||||
|
|
||||||
include("cmake/license.cmake")
|
|
||||||
license_add_file("llama.cpp" "LICENSE")
|
|
||||||
|
|
||||||
#
|
|
||||||
# 3rd-party
|
|
||||||
#
|
|
||||||
|
|
||||||
if (LLAMA_USE_SYSTEM_GGML)
|
|
||||||
message(STATUS "Using system-provided libggml, skipping ggml build")
|
|
||||||
find_package(ggml REQUIRED)
|
|
||||||
add_library(ggml ALIAS ggml::ggml)
|
|
||||||
endif()
|
|
||||||
|
|
||||||
if (NOT TARGET ggml AND NOT LLAMA_USE_SYSTEM_GGML)
|
|
||||||
set(GGML_BUILD_NUMBER ${LLAMA_BUILD_NUMBER})
|
|
||||||
set(GGML_BUILD_COMMIT ${LLAMA_BUILD_COMMIT})
|
|
||||||
add_subdirectory(ggml)
|
|
||||||
# ... otherwise assume ggml is added by a parent CMakeLists.txt
|
|
||||||
endif()
|
|
||||||
|
|
||||||
#
|
|
||||||
# build the library
|
|
||||||
#
|
|
||||||
|
|
||||||
add_subdirectory(src)
|
|
||||||
|
|
||||||
#
|
|
||||||
# utils, programs, examples and tests
|
|
||||||
#
|
|
||||||
|
|
||||||
if (NOT LLAMA_BUILD_COMMON)
|
|
||||||
message(STATUS "LLAMA_BUILD_COMMON is OFF, disabling LLAMA_CURL")
|
|
||||||
set(LLAMA_CURL OFF)
|
|
||||||
endif()
|
|
||||||
|
|
||||||
if (LLAMA_BUILD_COMMON)
|
|
||||||
add_subdirectory(common)
|
|
||||||
if (LLAMA_HTTPLIB)
|
|
||||||
add_subdirectory(vendor/cpp-httplib)
|
|
||||||
endif()
|
|
||||||
endif()
|
|
||||||
|
|
||||||
if (LLAMA_BUILD_COMMON AND LLAMA_BUILD_TESTS AND NOT CMAKE_JS_VERSION)
|
|
||||||
include(CTest)
|
|
||||||
add_subdirectory(tests)
|
|
||||||
endif()
|
|
||||||
|
|
||||||
if (LLAMA_BUILD_COMMON AND LLAMA_BUILD_EXAMPLES)
|
|
||||||
add_subdirectory(examples)
|
|
||||||
add_subdirectory(pocs)
|
|
||||||
endif()
|
|
||||||
|
|
||||||
if (LLAMA_BUILD_COMMON AND LLAMA_BUILD_TOOLS)
|
|
||||||
add_subdirectory(tools)
|
|
||||||
endif()
|
|
||||||
|
|
||||||
# Automatically add all files from the 'licenses' directory
|
|
||||||
file(GLOB EXTRA_LICENSES "${CMAKE_SOURCE_DIR}/licenses/LICENSE-*")
|
|
||||||
|
|
||||||
foreach(FILE_PATH ${EXTRA_LICENSES})
|
|
||||||
get_filename_component(FILE_NAME "${FILE_PATH}" NAME)
|
|
||||||
string(REGEX REPLACE "^LICENSE-" "" NAME "${FILE_NAME}")
|
|
||||||
license_add_file("${NAME}" "${FILE_PATH}")
|
|
||||||
endforeach()
|
|
||||||
|
|
||||||
if (LLAMA_BUILD_COMMON)
|
|
||||||
license_generate(common)
|
|
||||||
endif()
|
|
||||||
|
|
||||||
#
|
|
||||||
# install
|
|
||||||
#
|
|
||||||
|
|
||||||
include(GNUInstallDirs)
|
|
||||||
include(CMakePackageConfigHelpers)
|
|
||||||
|
|
||||||
set(LLAMA_INCLUDE_INSTALL_DIR ${CMAKE_INSTALL_INCLUDEDIR} CACHE PATH "Location of header files")
|
|
||||||
set(LLAMA_LIB_INSTALL_DIR ${CMAKE_INSTALL_LIBDIR} CACHE PATH "Location of library files")
|
|
||||||
set(LLAMA_BIN_INSTALL_DIR ${CMAKE_INSTALL_BINDIR} CACHE PATH "Location of binary files")
|
|
||||||
|
|
||||||
set(LLAMA_PUBLIC_HEADERS
|
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/include/llama.h
|
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/include/llama-cpp.h)
|
|
||||||
|
|
||||||
set_target_properties(llama
|
|
||||||
PROPERTIES
|
|
||||||
PUBLIC_HEADER "${LLAMA_PUBLIC_HEADERS}")
|
|
||||||
|
|
||||||
install(TARGETS llama LIBRARY PUBLIC_HEADER)
|
|
||||||
|
|
||||||
configure_package_config_file(
|
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/cmake/llama-config.cmake.in
|
|
||||||
${CMAKE_CURRENT_BINARY_DIR}/llama-config.cmake
|
|
||||||
INSTALL_DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/llama
|
|
||||||
PATH_VARS LLAMA_INCLUDE_INSTALL_DIR
|
|
||||||
LLAMA_LIB_INSTALL_DIR
|
|
||||||
LLAMA_BIN_INSTALL_DIR )
|
|
||||||
|
|
||||||
write_basic_package_version_file(
|
|
||||||
${CMAKE_CURRENT_BINARY_DIR}/llama-version.cmake
|
|
||||||
VERSION ${LLAMA_INSTALL_VERSION}
|
|
||||||
COMPATIBILITY SameMajorVersion)
|
|
||||||
|
|
||||||
install(FILES ${CMAKE_CURRENT_BINARY_DIR}/llama-config.cmake
|
|
||||||
${CMAKE_CURRENT_BINARY_DIR}/llama-version.cmake
|
|
||||||
DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/llama)
|
|
||||||
|
|
||||||
install(
|
|
||||||
FILES convert_hf_to_gguf.py
|
|
||||||
PERMISSIONS
|
|
||||||
OWNER_READ
|
|
||||||
OWNER_WRITE
|
|
||||||
OWNER_EXECUTE
|
|
||||||
GROUP_READ
|
|
||||||
GROUP_EXECUTE
|
|
||||||
WORLD_READ
|
|
||||||
WORLD_EXECUTE
|
|
||||||
DESTINATION ${CMAKE_INSTALL_BINDIR})
|
|
||||||
|
|
||||||
configure_file(cmake/llama.pc.in
|
|
||||||
"${CMAKE_CURRENT_BINARY_DIR}/llama.pc"
|
|
||||||
@ONLY)
|
|
||||||
|
|
||||||
install(FILES "${CMAKE_CURRENT_BINARY_DIR}/llama.pc"
|
|
||||||
DESTINATION ${CMAKE_INSTALL_LIBDIR}/pkgconfig)
|
|
||||||
@@ -1,16 +0,0 @@
|
|||||||
set( CMAKE_SYSTEM_NAME Darwin )
|
|
||||||
set( CMAKE_SYSTEM_PROCESSOR arm64 )
|
|
||||||
|
|
||||||
set( target arm64-apple-darwin-macho )
|
|
||||||
|
|
||||||
set( CMAKE_C_COMPILER clang )
|
|
||||||
set( CMAKE_CXX_COMPILER clang++ )
|
|
||||||
|
|
||||||
set( CMAKE_C_COMPILER_TARGET ${target} )
|
|
||||||
set( CMAKE_CXX_COMPILER_TARGET ${target} )
|
|
||||||
|
|
||||||
set( arch_c_flags "-march=armv8.4-a -fvectorize -ffp-model=fast -fno-finite-math-only" )
|
|
||||||
set( warn_c_flags "-Wno-format -Wno-unused-variable -Wno-unused-function" )
|
|
||||||
|
|
||||||
set( CMAKE_C_FLAGS_INIT "${arch_c_flags} ${warn_c_flags}" )
|
|
||||||
set( CMAKE_CXX_FLAGS_INIT "${arch_c_flags} ${warn_c_flags}" )
|
|
||||||
@@ -1,16 +0,0 @@
|
|||||||
set( CMAKE_SYSTEM_NAME Windows )
|
|
||||||
set( CMAKE_SYSTEM_PROCESSOR arm64 )
|
|
||||||
|
|
||||||
set( target arm64-pc-windows-msvc )
|
|
||||||
|
|
||||||
set( CMAKE_C_COMPILER clang )
|
|
||||||
set( CMAKE_CXX_COMPILER clang++ )
|
|
||||||
|
|
||||||
set( CMAKE_C_COMPILER_TARGET ${target} )
|
|
||||||
set( CMAKE_CXX_COMPILER_TARGET ${target} )
|
|
||||||
|
|
||||||
set( arch_c_flags "-march=armv8.7-a -fvectorize -ffp-model=fast -fno-finite-math-only" )
|
|
||||||
set( warn_c_flags "-Wno-format -Wno-unused-variable -Wno-unused-function -Wno-gnu-zero-variadic-macro-arguments" )
|
|
||||||
|
|
||||||
set( CMAKE_C_FLAGS_INIT "${arch_c_flags} ${warn_c_flags}" )
|
|
||||||
set( CMAKE_CXX_FLAGS_INIT "${arch_c_flags} ${warn_c_flags}" )
|
|
||||||
@@ -1,48 +0,0 @@
|
|||||||
set(BUILD_NUMBER 0)
|
|
||||||
set(BUILD_COMMIT "unknown")
|
|
||||||
set(BUILD_COMPILER "unknown")
|
|
||||||
set(BUILD_TARGET "unknown")
|
|
||||||
|
|
||||||
# Look for git
|
|
||||||
find_package(Git)
|
|
||||||
if(NOT Git_FOUND)
|
|
||||||
find_program(GIT_EXECUTABLE NAMES git git.exe)
|
|
||||||
if(GIT_EXECUTABLE)
|
|
||||||
set(Git_FOUND TRUE)
|
|
||||||
message(STATUS "Found Git: ${GIT_EXECUTABLE}")
|
|
||||||
else()
|
|
||||||
message(WARNING "Git not found. Build info will not be accurate.")
|
|
||||||
endif()
|
|
||||||
endif()
|
|
||||||
|
|
||||||
# Get the commit count and hash
|
|
||||||
if(Git_FOUND)
|
|
||||||
execute_process(
|
|
||||||
COMMAND ${GIT_EXECUTABLE} rev-parse --short HEAD
|
|
||||||
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
|
|
||||||
OUTPUT_VARIABLE HEAD
|
|
||||||
OUTPUT_STRIP_TRAILING_WHITESPACE
|
|
||||||
RESULT_VARIABLE RES
|
|
||||||
)
|
|
||||||
if (RES EQUAL 0)
|
|
||||||
set(BUILD_COMMIT ${HEAD})
|
|
||||||
endif()
|
|
||||||
execute_process(
|
|
||||||
COMMAND ${GIT_EXECUTABLE} rev-list --count HEAD
|
|
||||||
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
|
|
||||||
OUTPUT_VARIABLE COUNT
|
|
||||||
OUTPUT_STRIP_TRAILING_WHITESPACE
|
|
||||||
RESULT_VARIABLE RES
|
|
||||||
)
|
|
||||||
if (RES EQUAL 0)
|
|
||||||
set(BUILD_NUMBER ${COUNT})
|
|
||||||
endif()
|
|
||||||
endif()
|
|
||||||
|
|
||||||
set(BUILD_COMPILER "${CMAKE_C_COMPILER_ID} ${CMAKE_C_COMPILER_VERSION}")
|
|
||||||
|
|
||||||
if(CMAKE_VS_PLATFORM_NAME)
|
|
||||||
set(BUILD_TARGET ${CMAKE_VS_PLATFORM_NAME})
|
|
||||||
else()
|
|
||||||
set(BUILD_TARGET "${CMAKE_SYSTEM_NAME} ${CMAKE_SYSTEM_PROCESSOR}")
|
|
||||||
endif()
|
|
||||||
@@ -1,35 +0,0 @@
|
|||||||
include("ggml/cmake/common.cmake")
|
|
||||||
|
|
||||||
function(llama_add_compile_flags)
|
|
||||||
if (LLAMA_FATAL_WARNINGS)
|
|
||||||
if (CMAKE_CXX_COMPILER_ID MATCHES "GNU" OR CMAKE_CXX_COMPILER_ID MATCHES "Clang")
|
|
||||||
list(APPEND C_FLAGS -Werror)
|
|
||||||
list(APPEND CXX_FLAGS -Werror)
|
|
||||||
elseif (CMAKE_CXX_COMPILER_ID STREQUAL "MSVC")
|
|
||||||
add_compile_options(/WX)
|
|
||||||
endif()
|
|
||||||
endif()
|
|
||||||
|
|
||||||
if (LLAMA_ALL_WARNINGS)
|
|
||||||
if (NOT MSVC)
|
|
||||||
list(APPEND C_FLAGS -Wshadow -Wstrict-prototypes -Wpointer-arith -Wmissing-prototypes
|
|
||||||
-Werror=implicit-int -Werror=implicit-function-declaration)
|
|
||||||
|
|
||||||
list(APPEND CXX_FLAGS -Wmissing-declarations -Wmissing-noreturn)
|
|
||||||
|
|
||||||
list(APPEND WARNING_FLAGS -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function)
|
|
||||||
|
|
||||||
list(APPEND C_FLAGS ${WARNING_FLAGS})
|
|
||||||
list(APPEND CXX_FLAGS ${WARNING_FLAGS})
|
|
||||||
|
|
||||||
ggml_get_flags(${CMAKE_CXX_COMPILER_ID} ${CMAKE_CXX_COMPILER_VERSION})
|
|
||||||
|
|
||||||
add_compile_options("$<$<COMPILE_LANGUAGE:C>:${C_FLAGS};${GF_C_FLAGS}>"
|
|
||||||
"$<$<COMPILE_LANGUAGE:CXX>:${CXX_FLAGS};${GF_CXX_FLAGS}>")
|
|
||||||
else()
|
|
||||||
# todo : msvc
|
|
||||||
set(C_FLAGS "" PARENT_SCOPE)
|
|
||||||
set(CXX_FLAGS "" PARENT_SCOPE)
|
|
||||||
endif()
|
|
||||||
endif()
|
|
||||||
endfunction()
|
|
||||||
@@ -1,22 +0,0 @@
|
|||||||
find_package(Git)
|
|
||||||
|
|
||||||
# the commit's SHA1
|
|
||||||
execute_process(COMMAND
|
|
||||||
"${GIT_EXECUTABLE}" describe --match=NeVeRmAtCh --always --abbrev=8
|
|
||||||
WORKING_DIRECTORY "${CMAKE_SOURCE_DIR}"
|
|
||||||
OUTPUT_VARIABLE GIT_SHA1
|
|
||||||
ERROR_QUIET OUTPUT_STRIP_TRAILING_WHITESPACE)
|
|
||||||
|
|
||||||
# the date of the commit
|
|
||||||
execute_process(COMMAND
|
|
||||||
"${GIT_EXECUTABLE}" log -1 --format=%ad --date=local
|
|
||||||
WORKING_DIRECTORY "${CMAKE_SOURCE_DIR}"
|
|
||||||
OUTPUT_VARIABLE GIT_DATE
|
|
||||||
ERROR_QUIET OUTPUT_STRIP_TRAILING_WHITESPACE)
|
|
||||||
|
|
||||||
# the subject of the commit
|
|
||||||
execute_process(COMMAND
|
|
||||||
"${GIT_EXECUTABLE}" log -1 --format=%s
|
|
||||||
WORKING_DIRECTORY "${CMAKE_SOURCE_DIR}"
|
|
||||||
OUTPUT_VARIABLE GIT_COMMIT_SUBJECT
|
|
||||||
ERROR_QUIET OUTPUT_STRIP_TRAILING_WHITESPACE)
|
|
||||||
@@ -1,40 +0,0 @@
|
|||||||
define_property(GLOBAL PROPERTY LICENSE_TEXT
|
|
||||||
BRIEF_DOCS "Embedded licenses"
|
|
||||||
FULL_DOCS "Global string containing all aggregated licenses"
|
|
||||||
)
|
|
||||||
|
|
||||||
function(license_add_file NAME FILE)
|
|
||||||
if(NOT IS_ABSOLUTE "${FILE}")
|
|
||||||
set(FILE "${CMAKE_CURRENT_SOURCE_DIR}/${FILE}")
|
|
||||||
endif()
|
|
||||||
if(EXISTS "${FILE}")
|
|
||||||
set(TITLE "License for ${NAME}")
|
|
||||||
string(REGEX REPLACE "." "=" UNDERLINE "${TITLE}")
|
|
||||||
file(READ "${FILE}" TEXT)
|
|
||||||
get_property(TMP GLOBAL PROPERTY LICENSE_TEXT)
|
|
||||||
string(APPEND TMP "R\"=L=(${TITLE}\n${UNDERLINE}\n\n${TEXT})=L=\",\n")
|
|
||||||
set_property(GLOBAL PROPERTY LICENSE_TEXT "${TMP}")
|
|
||||||
else()
|
|
||||||
message(WARNING "License file '${FILE}' not found")
|
|
||||||
endif()
|
|
||||||
endfunction()
|
|
||||||
|
|
||||||
function(license_generate TARGET_NAME)
|
|
||||||
message(STATUS "Generating embedded license file for target: ${TARGET_NAME}")
|
|
||||||
get_property(TEXT GLOBAL PROPERTY LICENSE_TEXT)
|
|
||||||
|
|
||||||
set(CPP_CONTENT "// Generated by CMake\n\n")
|
|
||||||
string(APPEND CPP_CONTENT "const char* LICENSES[] = {\n")
|
|
||||||
string(APPEND CPP_CONTENT "${TEXT}")
|
|
||||||
string(APPEND CPP_CONTENT "nullptr\n")
|
|
||||||
string(APPEND CPP_CONTENT "};\n")
|
|
||||||
|
|
||||||
set(CPP_FILE "${CMAKE_BINARY_DIR}/license.cpp")
|
|
||||||
file(WRITE "${CPP_FILE}" "${CPP_CONTENT}")
|
|
||||||
|
|
||||||
if(TARGET ${TARGET_NAME})
|
|
||||||
target_sources(${TARGET_NAME} PRIVATE "${CPP_FILE}")
|
|
||||||
else()
|
|
||||||
message(FATAL_ERROR "Target '${TARGET_NAME}' does not exist")
|
|
||||||
endif()
|
|
||||||
endfunction()
|
|
||||||
@@ -1,30 +0,0 @@
|
|||||||
set(LLAMA_VERSION @LLAMA_INSTALL_VERSION@)
|
|
||||||
set(LLAMA_BUILD_COMMIT @LLAMA_BUILD_COMMIT@)
|
|
||||||
set(LLAMA_BUILD_NUMBER @LLAMA_BUILD_NUMBER@)
|
|
||||||
set(LLAMA_SHARED_LIB @BUILD_SHARED_LIBS@)
|
|
||||||
|
|
||||||
@PACKAGE_INIT@
|
|
||||||
|
|
||||||
set_and_check(LLAMA_INCLUDE_DIR "@PACKAGE_LLAMA_INCLUDE_INSTALL_DIR@")
|
|
||||||
set_and_check(LLAMA_LIB_DIR "@PACKAGE_LLAMA_LIB_INSTALL_DIR@")
|
|
||||||
set_and_check(LLAMA_BIN_DIR "@PACKAGE_LLAMA_BIN_INSTALL_DIR@")
|
|
||||||
|
|
||||||
find_package(ggml REQUIRED HINTS ${LLAMA_LIB_DIR}/cmake)
|
|
||||||
|
|
||||||
find_library(llama_LIBRARY llama
|
|
||||||
REQUIRED
|
|
||||||
HINTS ${LLAMA_LIB_DIR}
|
|
||||||
NO_CMAKE_FIND_ROOT_PATH
|
|
||||||
)
|
|
||||||
|
|
||||||
add_library(llama UNKNOWN IMPORTED)
|
|
||||||
set_target_properties(llama
|
|
||||||
PROPERTIES
|
|
||||||
INTERFACE_INCLUDE_DIRECTORIES "${LLAMA_INCLUDE_DIR}"
|
|
||||||
INTERFACE_LINK_LIBRARIES "ggml::ggml;ggml::ggml-base;"
|
|
||||||
IMPORTED_LINK_INTERFACE_LANGUAGES "CXX"
|
|
||||||
IMPORTED_LOCATION "${llama_LIBRARY}"
|
|
||||||
INTERFACE_COMPILE_FEATURES c_std_90
|
|
||||||
POSITION_INDEPENDENT_CODE ON)
|
|
||||||
|
|
||||||
check_required_components(Llama)
|
|
||||||
@@ -1,10 +0,0 @@
|
|||||||
prefix=@CMAKE_INSTALL_PREFIX@
|
|
||||||
exec_prefix=@CMAKE_INSTALL_PREFIX@
|
|
||||||
libdir=@CMAKE_INSTALL_FULL_LIBDIR@
|
|
||||||
includedir=@CMAKE_INSTALL_FULL_INCLUDEDIR@
|
|
||||||
|
|
||||||
Name: llama
|
|
||||||
Description: Port of Facebook's LLaMA model in C/C++
|
|
||||||
Version: @LLAMA_INSTALL_VERSION@
|
|
||||||
Libs: -L${libdir} -lggml -lggml-base -lllama
|
|
||||||
Cflags: -I${includedir}
|
|
||||||
@@ -1,29 +0,0 @@
|
|||||||
set(CMAKE_SYSTEM_NAME Linux)
|
|
||||||
set(CMAKE_SYSTEM_PROCESSOR riscv64)
|
|
||||||
set(CMAKE_SYSTEM_VERSION 1)
|
|
||||||
|
|
||||||
if (CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "^(riscv)")
|
|
||||||
message(STATUS "HOST SYSTEM ${CMAKE_HOST_SYSTEM_PROCESSOR}")
|
|
||||||
else()
|
|
||||||
set(GNU_MACHINE riscv64-unknown-linux-gnu CACHE STRING "GNU compiler triple")
|
|
||||||
if (DEFINED ENV{RISCV_ROOT_PATH})
|
|
||||||
file(TO_CMAKE_PATH $ENV{RISCV_ROOT_PATH} RISCV_ROOT_PATH)
|
|
||||||
else()
|
|
||||||
message(FATAL_ERROR "RISCV_ROOT_PATH env must be defined")
|
|
||||||
endif()
|
|
||||||
|
|
||||||
set(RISCV_ROOT_PATH ${RISCV_ROOT_PATH} CACHE STRING "root path to riscv toolchain")
|
|
||||||
set(CMAKE_C_COMPILER ${RISCV_ROOT_PATH}/bin/riscv64-unknown-linux-gnu-gcc)
|
|
||||||
set(CMAKE_CXX_COMPILER ${RISCV_ROOT_PATH}/bin/riscv64-unknown-linux-gnu-g++)
|
|
||||||
set(CMAKE_STRIP ${RISCV_ROOT_PATH}/bin/riscv64-unknown-linux-gnu-strip)
|
|
||||||
set(CMAKE_FIND_ROOT_PATH "${RISCV_ROOT_PATH}/riscv64-unknown-linux-gnu")
|
|
||||||
set(CMAKE_SYSROOT "${RISCV_ROOT_PATH}/sysroot")
|
|
||||||
endif()
|
|
||||||
|
|
||||||
set(CMAKE_FIND_ROOT_PATH_MODE_PROGRAM NEVER)
|
|
||||||
set(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY ONLY)
|
|
||||||
set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE ONLY)
|
|
||||||
set(CMAKE_FIND_ROOT_PATH_MODE_PACKAGE ONLY)
|
|
||||||
set(CMAKE_C_FLAGS "-march=rv64gcv_zfh_zba_zicbop -mabi=lp64d ${CMAKE_C_FLAGS}")
|
|
||||||
set(CMAKE_CXX_FLAGS "-march=rv64gcv_zfh_zba_zicbop -mabi=lp64d ${CXX_FLAGS}")
|
|
||||||
set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -latomic")
|
|
||||||
@@ -1,5 +0,0 @@
|
|||||||
set( CMAKE_SYSTEM_NAME Windows )
|
|
||||||
set( CMAKE_SYSTEM_PROCESSOR x86_64 )
|
|
||||||
|
|
||||||
set( CMAKE_C_COMPILER clang )
|
|
||||||
set( CMAKE_CXX_COMPILER clang++ )
|
|
||||||
@@ -1,157 +0,0 @@
|
|||||||
# common
|
|
||||||
|
|
||||||
find_package(Threads REQUIRED)
|
|
||||||
|
|
||||||
llama_add_compile_flags()
|
|
||||||
|
|
||||||
# Build info header
|
|
||||||
#
|
|
||||||
|
|
||||||
if(EXISTS "${PROJECT_SOURCE_DIR}/.git")
|
|
||||||
set(GIT_DIR "${PROJECT_SOURCE_DIR}/.git")
|
|
||||||
|
|
||||||
# Is git submodule
|
|
||||||
if(NOT IS_DIRECTORY "${GIT_DIR}")
|
|
||||||
file(READ ${GIT_DIR} REAL_GIT_DIR_LINK)
|
|
||||||
string(REGEX REPLACE "gitdir: (.*)\n$" "\\1" REAL_GIT_DIR ${REAL_GIT_DIR_LINK})
|
|
||||||
string(FIND "${REAL_GIT_DIR}" "/" SLASH_POS)
|
|
||||||
if (SLASH_POS EQUAL 0)
|
|
||||||
set(GIT_DIR "${REAL_GIT_DIR}")
|
|
||||||
else()
|
|
||||||
set(GIT_DIR "${PROJECT_SOURCE_DIR}/${REAL_GIT_DIR}")
|
|
||||||
endif()
|
|
||||||
endif()
|
|
||||||
|
|
||||||
if(EXISTS "${GIT_DIR}/index")
|
|
||||||
# For build-info.cpp below
|
|
||||||
set_property(DIRECTORY APPEND PROPERTY CMAKE_CONFIGURE_DEPENDS "${GIT_DIR}/index")
|
|
||||||
else()
|
|
||||||
message(WARNING "Git index not found in git repository.")
|
|
||||||
endif()
|
|
||||||
else()
|
|
||||||
message(WARNING "Git repository not found; to enable automatic generation of build info, make sure Git is installed and the project is a Git repository.")
|
|
||||||
endif()
|
|
||||||
|
|
||||||
set(TEMPLATE_FILE "${CMAKE_CURRENT_SOURCE_DIR}/build-info.cpp.in")
|
|
||||||
set(OUTPUT_FILE "${CMAKE_CURRENT_BINARY_DIR}/build-info.cpp")
|
|
||||||
configure_file(${TEMPLATE_FILE} ${OUTPUT_FILE})
|
|
||||||
|
|
||||||
set(TARGET build_info)
|
|
||||||
add_library(${TARGET} OBJECT ${OUTPUT_FILE})
|
|
||||||
if (BUILD_SHARED_LIBS)
|
|
||||||
set_target_properties(${TARGET} PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
|
||||||
endif()
|
|
||||||
|
|
||||||
set(TARGET common)
|
|
||||||
|
|
||||||
add_library(${TARGET} STATIC
|
|
||||||
arg.cpp
|
|
||||||
arg.h
|
|
||||||
base64.hpp
|
|
||||||
chat-parser.cpp
|
|
||||||
chat-parser.h
|
|
||||||
chat-parser-xml-toolcall.h
|
|
||||||
chat-parser-xml-toolcall.cpp
|
|
||||||
chat-peg-parser.cpp
|
|
||||||
chat-peg-parser.h
|
|
||||||
chat.cpp
|
|
||||||
chat.h
|
|
||||||
common.cpp
|
|
||||||
common.h
|
|
||||||
console.cpp
|
|
||||||
console.h
|
|
||||||
download.cpp
|
|
||||||
download.h
|
|
||||||
http.h
|
|
||||||
json-partial.cpp
|
|
||||||
json-partial.h
|
|
||||||
json-schema-to-grammar.cpp
|
|
||||||
llguidance.cpp
|
|
||||||
log.cpp
|
|
||||||
log.h
|
|
||||||
ngram-cache.cpp
|
|
||||||
ngram-cache.h
|
|
||||||
peg-parser.cpp
|
|
||||||
peg-parser.h
|
|
||||||
preset.cpp
|
|
||||||
preset.h
|
|
||||||
regex-partial.cpp
|
|
||||||
regex-partial.h
|
|
||||||
sampling.cpp
|
|
||||||
sampling.h
|
|
||||||
speculative.cpp
|
|
||||||
speculative.h
|
|
||||||
unicode.cpp
|
|
||||||
unicode.h
|
|
||||||
)
|
|
||||||
|
|
||||||
target_include_directories(${TARGET} PUBLIC . ../vendor)
|
|
||||||
target_compile_features (${TARGET} PUBLIC cxx_std_17)
|
|
||||||
|
|
||||||
if (BUILD_SHARED_LIBS)
|
|
||||||
set_target_properties(${TARGET} PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
|
||||||
endif()
|
|
||||||
|
|
||||||
# TODO: use list(APPEND LLAMA_COMMON_EXTRA_LIBS ...)
|
|
||||||
set(LLAMA_COMMON_EXTRA_LIBS build_info)
|
|
||||||
|
|
||||||
if (LLAMA_CURL)
|
|
||||||
# Use curl to download model url
|
|
||||||
find_package(CURL)
|
|
||||||
if (NOT CURL_FOUND)
|
|
||||||
message(FATAL_ERROR "Could NOT find CURL. Hint: to disable this feature, set -DLLAMA_CURL=OFF")
|
|
||||||
endif()
|
|
||||||
target_compile_definitions(${TARGET} PUBLIC LLAMA_USE_CURL)
|
|
||||||
include_directories(${CURL_INCLUDE_DIRS})
|
|
||||||
set(LLAMA_COMMON_EXTRA_LIBS ${LLAMA_COMMON_EXTRA_LIBS} ${CURL_LIBRARIES})
|
|
||||||
elseif (LLAMA_HTTPLIB)
|
|
||||||
# otherwise, use cpp-httplib
|
|
||||||
target_compile_definitions(${TARGET} PUBLIC LLAMA_USE_HTTPLIB)
|
|
||||||
set(LLAMA_COMMON_EXTRA_LIBS ${LLAMA_COMMON_EXTRA_LIBS} cpp-httplib)
|
|
||||||
endif()
|
|
||||||
|
|
||||||
if (LLAMA_LLGUIDANCE)
|
|
||||||
include(ExternalProject)
|
|
||||||
set(LLGUIDANCE_SRC ${CMAKE_BINARY_DIR}/llguidance/source)
|
|
||||||
set(LLGUIDANCE_PATH ${LLGUIDANCE_SRC}/target/release)
|
|
||||||
|
|
||||||
# Set the correct library file extension based on platform
|
|
||||||
if (WIN32)
|
|
||||||
set(LLGUIDANCE_LIB_NAME "llguidance.lib")
|
|
||||||
# Add Windows-specific libraries
|
|
||||||
set(LLGUIDANCE_PLATFORM_LIBS
|
|
||||||
ws2_32 # Windows Sockets API
|
|
||||||
userenv # For GetUserProfileDirectoryW
|
|
||||||
ntdll # For NT functions
|
|
||||||
bcrypt # For BCryptGenRandom
|
|
||||||
)
|
|
||||||
else()
|
|
||||||
set(LLGUIDANCE_LIB_NAME "libllguidance.a")
|
|
||||||
set(LLGUIDANCE_PLATFORM_LIBS "")
|
|
||||||
endif()
|
|
||||||
|
|
||||||
ExternalProject_Add(llguidance_ext
|
|
||||||
GIT_REPOSITORY https://github.com/guidance-ai/llguidance
|
|
||||||
# v1.0.1:
|
|
||||||
GIT_TAG d795912fedc7d393de740177ea9ea761e7905774
|
|
||||||
PREFIX ${CMAKE_BINARY_DIR}/llguidance
|
|
||||||
SOURCE_DIR ${LLGUIDANCE_SRC}
|
|
||||||
BUILD_IN_SOURCE TRUE
|
|
||||||
CONFIGURE_COMMAND ""
|
|
||||||
BUILD_COMMAND cargo build --release --package llguidance
|
|
||||||
INSTALL_COMMAND ""
|
|
||||||
BUILD_BYPRODUCTS ${LLGUIDANCE_PATH}/${LLGUIDANCE_LIB_NAME} ${LLGUIDANCE_PATH}/llguidance.h
|
|
||||||
UPDATE_COMMAND ""
|
|
||||||
)
|
|
||||||
target_compile_definitions(${TARGET} PUBLIC LLAMA_USE_LLGUIDANCE)
|
|
||||||
|
|
||||||
add_library(llguidance STATIC IMPORTED)
|
|
||||||
set_target_properties(llguidance PROPERTIES IMPORTED_LOCATION ${LLGUIDANCE_PATH}/${LLGUIDANCE_LIB_NAME})
|
|
||||||
add_dependencies(llguidance llguidance_ext)
|
|
||||||
|
|
||||||
target_include_directories(${TARGET} PRIVATE ${LLGUIDANCE_PATH})
|
|
||||||
# Add platform libraries to the main target
|
|
||||||
set(LLAMA_COMMON_EXTRA_LIBS ${LLAMA_COMMON_EXTRA_LIBS} llguidance ${LLGUIDANCE_PLATFORM_LIBS})
|
|
||||||
endif ()
|
|
||||||
|
|
||||||
target_link_libraries(${TARGET} PRIVATE ${LLAMA_COMMON_EXTRA_LIBS} PUBLIC llama Threads::Threads)
|
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -1,131 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
#include "common.h"
|
|
||||||
|
|
||||||
#include <set>
|
|
||||||
#include <map>
|
|
||||||
#include <string>
|
|
||||||
#include <vector>
|
|
||||||
#include <cstring>
|
|
||||||
|
|
||||||
// pseudo-env variable to identify preset-only arguments
|
|
||||||
#define COMMON_ARG_PRESET_LOAD_ON_STARTUP "__PRESET_LOAD_ON_STARTUP"
|
|
||||||
#define COMMON_ARG_PRESET_STOP_TIMEOUT "__PRESET_STOP_TIMEOUT"
|
|
||||||
|
|
||||||
//
|
|
||||||
// CLI argument parsing
|
|
||||||
//
|
|
||||||
|
|
||||||
struct common_arg {
|
|
||||||
std::set<enum llama_example> examples = {LLAMA_EXAMPLE_COMMON};
|
|
||||||
std::set<enum llama_example> excludes = {};
|
|
||||||
std::vector<const char *> args;
|
|
||||||
std::vector<const char *> args_neg; // for negated args like --no-xxx
|
|
||||||
const char * value_hint = nullptr; // help text or example for arg value
|
|
||||||
const char * value_hint_2 = nullptr; // for second arg value
|
|
||||||
const char * env = nullptr;
|
|
||||||
std::string help;
|
|
||||||
bool is_sparam = false; // is current arg a sampling param?
|
|
||||||
bool is_preset_only = false; // is current arg preset-only (not treated as CLI arg)
|
|
||||||
void (*handler_void) (common_params & params) = nullptr;
|
|
||||||
void (*handler_string) (common_params & params, const std::string &) = nullptr;
|
|
||||||
void (*handler_str_str)(common_params & params, const std::string &, const std::string &) = nullptr;
|
|
||||||
void (*handler_int) (common_params & params, int) = nullptr;
|
|
||||||
void (*handler_bool) (common_params & params, bool) = nullptr;
|
|
||||||
|
|
||||||
common_arg() = default;
|
|
||||||
|
|
||||||
common_arg(
|
|
||||||
const std::initializer_list<const char *> & args,
|
|
||||||
const char * value_hint,
|
|
||||||
const std::string & help,
|
|
||||||
void (*handler)(common_params & params, const std::string &)
|
|
||||||
) : args(args), value_hint(value_hint), help(help), handler_string(handler) {}
|
|
||||||
|
|
||||||
common_arg(
|
|
||||||
const std::initializer_list<const char *> & args,
|
|
||||||
const char * value_hint,
|
|
||||||
const std::string & help,
|
|
||||||
void (*handler)(common_params & params, int)
|
|
||||||
) : args(args), value_hint(value_hint), help(help), handler_int(handler) {}
|
|
||||||
|
|
||||||
common_arg(
|
|
||||||
const std::initializer_list<const char *> & args,
|
|
||||||
const std::string & help,
|
|
||||||
void (*handler)(common_params & params)
|
|
||||||
) : args(args), help(help), handler_void(handler) {}
|
|
||||||
|
|
||||||
common_arg(
|
|
||||||
const std::initializer_list<const char *> & args,
|
|
||||||
const std::initializer_list<const char *> & args_neg,
|
|
||||||
const std::string & help,
|
|
||||||
void (*handler)(common_params & params, bool)
|
|
||||||
) : args(args), args_neg(args_neg), help(help), handler_bool(handler) {}
|
|
||||||
|
|
||||||
// support 2 values for arg
|
|
||||||
common_arg(
|
|
||||||
const std::initializer_list<const char *> & args,
|
|
||||||
const char * value_hint,
|
|
||||||
const char * value_hint_2,
|
|
||||||
const std::string & help,
|
|
||||||
void (*handler)(common_params & params, const std::string &, const std::string &)
|
|
||||||
) : args(args), value_hint(value_hint), value_hint_2(value_hint_2), help(help), handler_str_str(handler) {}
|
|
||||||
|
|
||||||
common_arg & set_examples(std::initializer_list<enum llama_example> examples);
|
|
||||||
common_arg & set_excludes(std::initializer_list<enum llama_example> excludes);
|
|
||||||
common_arg & set_env(const char * env);
|
|
||||||
common_arg & set_sparam();
|
|
||||||
common_arg & set_preset_only();
|
|
||||||
bool in_example(enum llama_example ex);
|
|
||||||
bool is_exclude(enum llama_example ex);
|
|
||||||
bool get_value_from_env(std::string & output) const;
|
|
||||||
bool has_value_from_env() const;
|
|
||||||
std::string to_string() const;
|
|
||||||
|
|
||||||
// for using as key in std::map
|
|
||||||
bool operator<(const common_arg& other) const {
|
|
||||||
if (args.empty() || other.args.empty()) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
return strcmp(args[0], other.args[0]) < 0;
|
|
||||||
}
|
|
||||||
bool operator==(const common_arg& other) const {
|
|
||||||
if (args.empty() || other.args.empty()) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
return strcmp(args[0], other.args[0]) == 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
// get all args and env vars (including negated args/env)
|
|
||||||
std::vector<std::string> get_args() const;
|
|
||||||
std::vector<std::string> get_env() const;
|
|
||||||
};
|
|
||||||
|
|
||||||
namespace common_arg_utils {
|
|
||||||
bool is_truthy(const std::string & value);
|
|
||||||
bool is_falsey(const std::string & value);
|
|
||||||
bool is_autoy(const std::string & value);
|
|
||||||
}
|
|
||||||
|
|
||||||
struct common_params_context {
|
|
||||||
enum llama_example ex = LLAMA_EXAMPLE_COMMON;
|
|
||||||
common_params & params;
|
|
||||||
std::vector<common_arg> options;
|
|
||||||
void(*print_usage)(int, char **) = nullptr;
|
|
||||||
common_params_context(common_params & params) : params(params) {}
|
|
||||||
};
|
|
||||||
|
|
||||||
// parse input arguments from CLI
|
|
||||||
// if one argument has invalid value, it will automatically display usage of the specific argument (and not the full usage message)
|
|
||||||
bool common_params_parse(int argc, char ** argv, common_params & params, llama_example ex, void(*print_usage)(int, char **) = nullptr);
|
|
||||||
|
|
||||||
// parse input arguments from CLI into a map
|
|
||||||
bool common_params_to_map(int argc, char ** argv, llama_example ex, std::map<common_arg, std::string> & out_map);
|
|
||||||
|
|
||||||
// populate preset-only arguments
|
|
||||||
// these arguments are not treated as command line arguments
|
|
||||||
// see: https://github.com/ggml-org/llama.cpp/issues/18163
|
|
||||||
void common_params_add_preset_options(std::vector<common_arg> & args);
|
|
||||||
|
|
||||||
// initialize argument parser context - used by test-arg-parser and preset
|
|
||||||
common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **) = nullptr);
|
|
||||||
@@ -1,392 +0,0 @@
|
|||||||
/*
|
|
||||||
This is free and unencumbered software released into the public domain.
|
|
||||||
|
|
||||||
Anyone is free to copy, modify, publish, use, compile, sell, or
|
|
||||||
distribute this software, either in source code form or as a compiled
|
|
||||||
binary, for any purpose, commercial or non-commercial, and by any
|
|
||||||
means.
|
|
||||||
|
|
||||||
In jurisdictions that recognize copyright laws, the author or authors
|
|
||||||
of this software dedicate any and all copyright interest in the
|
|
||||||
software to the public domain. We make this dedication for the benefit
|
|
||||||
of the public at large and to the detriment of our heirs and
|
|
||||||
successors. We intend this dedication to be an overt act of
|
|
||||||
relinquishment in perpetuity of all present and future rights to this
|
|
||||||
software under copyright law.
|
|
||||||
|
|
||||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
|
||||||
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
|
||||||
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
|
||||||
IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR
|
|
||||||
OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
|
|
||||||
ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
|
|
||||||
OTHER DEALINGS IN THE SOFTWARE.
|
|
||||||
|
|
||||||
For more information, please refer to <http://unlicense.org>
|
|
||||||
*/
|
|
||||||
|
|
||||||
#ifndef PUBLIC_DOMAIN_BASE64_HPP_
|
|
||||||
#define PUBLIC_DOMAIN_BASE64_HPP_
|
|
||||||
|
|
||||||
#include <cstdint>
|
|
||||||
#include <iterator>
|
|
||||||
#include <stdexcept>
|
|
||||||
#include <string>
|
|
||||||
|
|
||||||
class base64_error : public std::runtime_error
|
|
||||||
{
|
|
||||||
public:
|
|
||||||
using std::runtime_error::runtime_error;
|
|
||||||
};
|
|
||||||
|
|
||||||
class base64
|
|
||||||
{
|
|
||||||
public:
|
|
||||||
enum class alphabet
|
|
||||||
{
|
|
||||||
/** the alphabet is detected automatically */
|
|
||||||
auto_,
|
|
||||||
/** the standard base64 alphabet is used */
|
|
||||||
standard,
|
|
||||||
/** like `standard` except that the characters `+` and `/` are replaced by `-` and `_` respectively*/
|
|
||||||
url_filename_safe
|
|
||||||
};
|
|
||||||
|
|
||||||
enum class decoding_behavior
|
|
||||||
{
|
|
||||||
/** if the input is not padded, the remaining bits are ignored */
|
|
||||||
moderate,
|
|
||||||
/** if a padding character is encounter decoding is finished */
|
|
||||||
loose
|
|
||||||
};
|
|
||||||
|
|
||||||
/**
|
|
||||||
Encodes all the elements from `in_begin` to `in_end` to `out`.
|
|
||||||
|
|
||||||
@warning The source and destination cannot overlap. The destination must be able to hold at least
|
|
||||||
`required_encode_size(std::distance(in_begin, in_end))`, otherwise the behavior depends on the output iterator.
|
|
||||||
|
|
||||||
@tparam Input_iterator the source; the returned elements are cast to `std::uint8_t` and should not be greater than
|
|
||||||
8 bits
|
|
||||||
@tparam Output_iterator the destination; the elements written to it are from the type `char`
|
|
||||||
@param in_begin the beginning of the source
|
|
||||||
@param in_end the ending of the source
|
|
||||||
@param out the destination iterator
|
|
||||||
@param alphabet which alphabet should be used
|
|
||||||
@returns the iterator to the next element past the last element copied
|
|
||||||
@throws see `Input_iterator` and `Output_iterator`
|
|
||||||
*/
|
|
||||||
template<typename Input_iterator, typename Output_iterator>
|
|
||||||
static Output_iterator encode(Input_iterator in_begin, Input_iterator in_end, Output_iterator out,
|
|
||||||
alphabet alphabet = alphabet::standard)
|
|
||||||
{
|
|
||||||
constexpr auto pad = '=';
|
|
||||||
const char* alpha = alphabet == alphabet::url_filename_safe
|
|
||||||
? "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_"
|
|
||||||
: "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
|
|
||||||
|
|
||||||
while (in_begin != in_end) {
|
|
||||||
std::uint8_t i0 = 0, i1 = 0, i2 = 0;
|
|
||||||
|
|
||||||
// first character
|
|
||||||
i0 = static_cast<std::uint8_t>(*in_begin);
|
|
||||||
++in_begin;
|
|
||||||
|
|
||||||
*out = alpha[i0 >> 2 & 0x3f];
|
|
||||||
++out;
|
|
||||||
|
|
||||||
// part of first character and second
|
|
||||||
if (in_begin != in_end) {
|
|
||||||
i1 = static_cast<std::uint8_t>(*in_begin);
|
|
||||||
++in_begin;
|
|
||||||
|
|
||||||
*out = alpha[((i0 & 0x3) << 4) | (i1 >> 4 & 0x0f)];
|
|
||||||
++out;
|
|
||||||
} else {
|
|
||||||
*out = alpha[(i0 & 0x3) << 4];
|
|
||||||
++out;
|
|
||||||
|
|
||||||
// last padding
|
|
||||||
*out = pad;
|
|
||||||
++out;
|
|
||||||
|
|
||||||
// last padding
|
|
||||||
*out = pad;
|
|
||||||
++out;
|
|
||||||
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
// part of second character and third
|
|
||||||
if (in_begin != in_end) {
|
|
||||||
i2 = static_cast<std::uint8_t>(*in_begin);
|
|
||||||
++in_begin;
|
|
||||||
|
|
||||||
*out = alpha[((i1 & 0xf) << 2) | (i2 >> 6 & 0x03)];
|
|
||||||
++out;
|
|
||||||
} else {
|
|
||||||
*out = alpha[(i1 & 0xf) << 2];
|
|
||||||
++out;
|
|
||||||
|
|
||||||
// last padding
|
|
||||||
*out = pad;
|
|
||||||
++out;
|
|
||||||
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
// rest of third
|
|
||||||
*out = alpha[i2 & 0x3f];
|
|
||||||
++out;
|
|
||||||
}
|
|
||||||
|
|
||||||
return out;
|
|
||||||
}
|
|
||||||
/**
|
|
||||||
Encodes a string.
|
|
||||||
|
|
||||||
@param str the string that should be encoded
|
|
||||||
@param alphabet which alphabet should be used
|
|
||||||
@returns the encoded base64 string
|
|
||||||
@throws see base64::encode()
|
|
||||||
*/
|
|
||||||
static std::string encode(const std::string& str, alphabet alphabet = alphabet::standard)
|
|
||||||
{
|
|
||||||
std::string result;
|
|
||||||
|
|
||||||
result.reserve(required_encode_size(str.length()) + 1);
|
|
||||||
|
|
||||||
encode(str.begin(), str.end(), std::back_inserter(result), alphabet);
|
|
||||||
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
/**
|
|
||||||
Encodes a char array.
|
|
||||||
|
|
||||||
@param buffer the char array
|
|
||||||
@param size the size of the array
|
|
||||||
@param alphabet which alphabet should be used
|
|
||||||
@returns the encoded string
|
|
||||||
*/
|
|
||||||
static std::string encode(const char* buffer, std::size_t size, alphabet alphabet = alphabet::standard)
|
|
||||||
{
|
|
||||||
std::string result;
|
|
||||||
|
|
||||||
result.reserve(required_encode_size(size) + 1);
|
|
||||||
|
|
||||||
encode(buffer, buffer + size, std::back_inserter(result), alphabet);
|
|
||||||
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
/**
|
|
||||||
Decodes all the elements from `in_begin` to `in_end` to `out`. `in_begin` may point to the same location as `out`,
|
|
||||||
in other words: inplace decoding is possible.
|
|
||||||
|
|
||||||
@warning The destination must be able to hold at least `required_decode_size(std::distance(in_begin, in_end))`,
|
|
||||||
otherwise the behavior depends on the output iterator.
|
|
||||||
|
|
||||||
@tparam Input_iterator the source; the returned elements are cast to `char`
|
|
||||||
@tparam Output_iterator the destination; the elements written to it are from the type `std::uint8_t`
|
|
||||||
@param in_begin the beginning of the source
|
|
||||||
@param in_end the ending of the source
|
|
||||||
@param out the destination iterator
|
|
||||||
@param alphabet which alphabet should be used
|
|
||||||
@param behavior the behavior when an error was detected
|
|
||||||
@returns the iterator to the next element past the last element copied
|
|
||||||
@throws base64_error depending on the set behavior
|
|
||||||
@throws see `Input_iterator` and `Output_iterator`
|
|
||||||
*/
|
|
||||||
template<typename Input_iterator, typename Output_iterator>
|
|
||||||
static Output_iterator decode(Input_iterator in_begin, Input_iterator in_end, Output_iterator out,
|
|
||||||
alphabet alphabet = alphabet::auto_,
|
|
||||||
decoding_behavior behavior = decoding_behavior::moderate)
|
|
||||||
{
|
|
||||||
//constexpr auto pad = '=';
|
|
||||||
std::uint8_t last = 0;
|
|
||||||
auto bits = 0;
|
|
||||||
|
|
||||||
while (in_begin != in_end) {
|
|
||||||
auto c = *in_begin;
|
|
||||||
++in_begin;
|
|
||||||
|
|
||||||
if (c == '=') {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto part = _base64_value(alphabet, c);
|
|
||||||
|
|
||||||
// enough bits for one byte
|
|
||||||
if (bits + 6 >= 8) {
|
|
||||||
*out = (last << (8 - bits)) | (part >> (bits - 2));
|
|
||||||
++out;
|
|
||||||
|
|
||||||
bits -= 2;
|
|
||||||
} else {
|
|
||||||
bits += 6;
|
|
||||||
}
|
|
||||||
|
|
||||||
last = part;
|
|
||||||
}
|
|
||||||
|
|
||||||
// check padding
|
|
||||||
if (behavior != decoding_behavior::loose) {
|
|
||||||
while (in_begin != in_end) {
|
|
||||||
auto c = *in_begin;
|
|
||||||
++in_begin;
|
|
||||||
|
|
||||||
if (c != '=') {
|
|
||||||
throw base64_error("invalid base64 character.");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return out;
|
|
||||||
}
|
|
||||||
/**
|
|
||||||
Decodes a string.
|
|
||||||
|
|
||||||
@param str the base64 encoded string
|
|
||||||
@param alphabet which alphabet should be used
|
|
||||||
@param behavior the behavior when an error was detected
|
|
||||||
@returns the decoded string
|
|
||||||
@throws see base64::decode()
|
|
||||||
*/
|
|
||||||
static std::string decode(const std::string& str, alphabet alphabet = alphabet::auto_,
|
|
||||||
decoding_behavior behavior = decoding_behavior::moderate)
|
|
||||||
{
|
|
||||||
std::string result;
|
|
||||||
|
|
||||||
result.reserve(max_decode_size(str.length()));
|
|
||||||
|
|
||||||
decode(str.begin(), str.end(), std::back_inserter(result), alphabet, behavior);
|
|
||||||
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
/**
|
|
||||||
Decodes a string.
|
|
||||||
|
|
||||||
@param buffer the base64 encoded buffer
|
|
||||||
@param size the size of the buffer
|
|
||||||
@param alphabet which alphabet should be used
|
|
||||||
@param behavior the behavior when an error was detected
|
|
||||||
@returns the decoded string
|
|
||||||
@throws see base64::decode()
|
|
||||||
*/
|
|
||||||
static std::string decode(const char* buffer, std::size_t size, alphabet alphabet = alphabet::auto_,
|
|
||||||
decoding_behavior behavior = decoding_behavior::moderate)
|
|
||||||
{
|
|
||||||
std::string result;
|
|
||||||
|
|
||||||
result.reserve(max_decode_size(size));
|
|
||||||
|
|
||||||
decode(buffer, buffer + size, std::back_inserter(result), alphabet, behavior);
|
|
||||||
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
/**
|
|
||||||
Decodes a string inplace.
|
|
||||||
|
|
||||||
@param[in,out] str the base64 encoded string
|
|
||||||
@param alphabet which alphabet should be used
|
|
||||||
@param behavior the behavior when an error was detected
|
|
||||||
@throws base64::decode_inplace()
|
|
||||||
*/
|
|
||||||
static void decode_inplace(std::string& str, alphabet alphabet = alphabet::auto_,
|
|
||||||
decoding_behavior behavior = decoding_behavior::moderate)
|
|
||||||
{
|
|
||||||
str.resize(decode(str.begin(), str.end(), str.begin(), alphabet, behavior) - str.begin());
|
|
||||||
}
|
|
||||||
/**
|
|
||||||
Decodes a char array inplace.
|
|
||||||
|
|
||||||
@param[in,out] str the string array
|
|
||||||
@param size the length of the array
|
|
||||||
@param alphabet which alphabet should be used
|
|
||||||
@param behavior the behavior when an error was detected
|
|
||||||
@returns the pointer to the next element past the last element decoded
|
|
||||||
@throws base64::decode_inplace()
|
|
||||||
*/
|
|
||||||
static char* decode_inplace(char* str, std::size_t size, alphabet alphabet = alphabet::auto_,
|
|
||||||
decoding_behavior behavior = decoding_behavior::moderate)
|
|
||||||
{
|
|
||||||
return decode(str, str + size, str, alphabet, behavior);
|
|
||||||
}
|
|
||||||
/**
|
|
||||||
Returns the required decoding size for a given size. The value is calculated with the following formula:
|
|
||||||
|
|
||||||
$$
|
|
||||||
\lceil \frac{size}{4} \rceil \cdot 3
|
|
||||||
$$
|
|
||||||
|
|
||||||
@param size the size of the encoded input
|
|
||||||
@returns the size of the resulting decoded buffer; this the absolute maximum
|
|
||||||
*/
|
|
||||||
static std::size_t max_decode_size(std::size_t size) noexcept
|
|
||||||
{
|
|
||||||
return (size / 4 + (size % 4 ? 1 : 0)) * 3;
|
|
||||||
}
|
|
||||||
/**
|
|
||||||
Returns the required encoding size for a given size. The value is calculated with the following formula:
|
|
||||||
|
|
||||||
$$
|
|
||||||
\lceil \frac{size}{3} \rceil \cdot 4
|
|
||||||
$$
|
|
||||||
|
|
||||||
@param size the size of the decoded input
|
|
||||||
@returns the size of the resulting encoded buffer
|
|
||||||
*/
|
|
||||||
static std::size_t required_encode_size(std::size_t size) noexcept
|
|
||||||
{
|
|
||||||
return (size / 3 + (size % 3 ? 1 : 0)) * 4;
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
static std::uint8_t _base64_value(alphabet& alphabet, char c)
|
|
||||||
{
|
|
||||||
if (c >= 'A' && c <= 'Z') {
|
|
||||||
return c - 'A';
|
|
||||||
} else if (c >= 'a' && c <= 'z') {
|
|
||||||
return c - 'a' + 26;
|
|
||||||
} else if (c >= '0' && c <= '9') {
|
|
||||||
return c - '0' + 52;
|
|
||||||
}
|
|
||||||
|
|
||||||
// comes down to alphabet
|
|
||||||
if (alphabet == alphabet::standard) {
|
|
||||||
if (c == '+') {
|
|
||||||
return 62;
|
|
||||||
} else if (c == '/') {
|
|
||||||
return 63;
|
|
||||||
}
|
|
||||||
} else if (alphabet == alphabet::url_filename_safe) {
|
|
||||||
if (c == '-') {
|
|
||||||
return 62;
|
|
||||||
} else if (c == '_') {
|
|
||||||
return 63;
|
|
||||||
}
|
|
||||||
} // auto detect
|
|
||||||
else {
|
|
||||||
if (c == '+') {
|
|
||||||
alphabet = alphabet::standard;
|
|
||||||
|
|
||||||
return 62;
|
|
||||||
} else if (c == '/') {
|
|
||||||
alphabet = alphabet::standard;
|
|
||||||
|
|
||||||
return 63;
|
|
||||||
} else if (c == '-') {
|
|
||||||
alphabet = alphabet::url_filename_safe;
|
|
||||||
|
|
||||||
return 62;
|
|
||||||
} else if (c == '_') {
|
|
||||||
alphabet = alphabet::url_filename_safe;
|
|
||||||
|
|
||||||
return 63;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
throw base64_error("invalid base64 character.");
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
#endif // !PUBLIC_DOMAIN_BASE64_HPP_
|
|
||||||
@@ -1,4 +0,0 @@
|
|||||||
int LLAMA_BUILD_NUMBER = @LLAMA_BUILD_NUMBER@;
|
|
||||||
char const *LLAMA_COMMIT = "@LLAMA_BUILD_COMMIT@";
|
|
||||||
char const *LLAMA_COMPILER = "@BUILD_COMPILER@";
|
|
||||||
char const *LLAMA_BUILD_TARGET = "@BUILD_TARGET@";
|
|
||||||
@@ -1,879 +0,0 @@
|
|||||||
#include "chat.h"
|
|
||||||
#include "chat-parser.h"
|
|
||||||
#include "common.h"
|
|
||||||
#include "json-partial.h"
|
|
||||||
#include "json-schema-to-grammar.h"
|
|
||||||
#include "log.h"
|
|
||||||
#include "regex-partial.h"
|
|
||||||
|
|
||||||
using json = nlohmann::ordered_json;
|
|
||||||
|
|
||||||
class xml_toolcall_syntax_exception : public std::runtime_error {
|
|
||||||
public:
|
|
||||||
xml_toolcall_syntax_exception(const std::string & message) : std::runtime_error(message) {}
|
|
||||||
};
|
|
||||||
|
|
||||||
template<typename T>
|
|
||||||
inline void sort_uniq(std::vector<T> &vec) {
|
|
||||||
std::sort(vec.begin(), vec.end());
|
|
||||||
vec.erase(std::unique(vec.begin(), vec.end()), vec.end());
|
|
||||||
}
|
|
||||||
|
|
||||||
template<typename T>
|
|
||||||
inline bool all_space(const T &str) {
|
|
||||||
return std::all_of(str.begin(), str.end(), [](unsigned char ch) { return std::isspace(ch); });
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
 * Return the length of the longest prefix of `s` that does not end in the
 * middle of a (possibly truncated) UTF-8 multi-byte sequence.
 *
 * Scans backwards over at most 4 bytes from the end:
 *  - an ASCII byte means the tail is not a split sequence: keep everything;
 *  - a lead byte tells us how many bytes the final sequence needs: keep the
 *    whole string if it is complete, otherwise cut the sequence off;
 *  - an invalid lead byte is cut off as well.
 * If only continuation bytes are seen (malformed input), up to 3 trailing
 * bytes are dropped as a fallback.
 */
static size_t utf8_truncate_safe(const std::string_view s) {
    const size_t len = s.size();
    if (len == 0) {
        return 0;
    }
    size_t pos = len;
    for (size_t steps = 0; steps < 4 && pos > 0; ++steps) {
        --pos;
        const unsigned char byte = (unsigned char) s[pos];
        if ((byte & 0x80) == 0) {
            // ASCII byte: no truncated sequence can end the string.
            return len;
        }
        if ((byte & 0xC0) != 0xC0) {
            // Continuation byte (10xxxxxx): keep scanning backwards.
            continue;
        }
        // Lead byte: how many bytes should this sequence occupy?
        size_t need;
        if ((byte & 0xE0) == 0xC0) {
            need = 2;
        } else if ((byte & 0xF0) == 0xE0) {
            need = 3;
        } else if ((byte & 0xF8) == 0xF0) {
            need = 4;
        } else {
            // Invalid lead byte: truncate it away.
            return pos;
        }
        // Complete sequence -> keep everything; incomplete -> cut at the lead.
        return (len - pos >= need) ? len : pos;
    }
    // No lead/ASCII byte within reach: drop up to 3 trailing bytes.
    return len - std::min(len, size_t(3));
}
|
|
||||||
|
|
||||||
/**
 * Shrink `s` in place so that it no longer ends in the middle of a UTF-8
 * multi-byte sequence (see utf8_truncate_safe).
 */
inline void utf8_truncate_safe_resize(std::string &s) {
    // erase-from-position is equivalent to resize() down to that length
    s.erase(utf8_truncate_safe(s));
}
|
|
||||||
|
|
||||||
inline std::string_view utf8_truncate_safe_view(const std::string_view s) {
|
|
||||||
return s.substr(0, utf8_truncate_safe(s));
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
 * Find `literal1` followed by `literal2`, with only whitespace allowed between
 * the two, scanning forward from the builder's current position.
 *
 * On success the returned match has:
 *  - prelude: the text between the starting position and literal1;
 *  - groups[0]: spanning from the start of literal1 to the end of literal2,
 *    and the builder positioned just past literal2.
 * On failure the builder position is restored and nullopt is returned.
 * If literal1 is empty, this degenerates to a plain search for literal2.
 */
static std::optional<common_chat_msg_parser::find_regex_result> try_find_2_literal_splited_by_spaces(common_chat_msg_parser & builder, const std::string & literal1, const std::string & literal2) {
    if (literal1.size() == 0) return builder.try_find_literal(literal2);
    const auto saved_pos = builder.pos();
    // Try each occurrence of literal1 until one is followed (modulo spaces) by literal2.
    while (auto res = builder.try_find_literal(literal1)) {
        builder.consume_spaces();
        // Allow a prefix match of literal2 when the input ends early (partial input).
        const auto match_len = std::min(literal2.size(), builder.input().size() - builder.pos());
        if (builder.input().compare(builder.pos(), match_len, literal2, 0, match_len) == 0) {
            // Recompute the prelude relative to the original starting position,
            // since intermediate failed attempts may have advanced the scan.
            if (res->prelude.size() != res->groups[0].begin - saved_pos) {
                res->prelude = builder.str({saved_pos, res->groups[0].begin});
            }
            builder.move_to(builder.pos() + match_len);
            res->groups[0].end = builder.pos();
            GGML_ASSERT(res->groups[0].begin != res->groups[0].end);
            return res;
        }
        // No literal2 after this occurrence: resume the search one byte further.
        builder.move_to(res->groups[0].begin + 1);
    }
    // Not found: restore the position so the caller sees no side effects.
    builder.move_to(saved_pos);
    return std::nullopt;
}
|
|
||||||
|
|
||||||
/**
 * Make a GBNF expression that accepts any string NOT containing any of the
 * forbidden substrings.
 *
 * Builds a trie over the (sorted) forbidden strings and emits, per trie node,
 * an alternation of: a negated character class over the node's child bytes,
 * plus one branch per child byte continuing into the subtree. The whole
 * expression is starred so arbitrary-length output is accepted.
 * Note: strings ending in a proper prefix of a forbidden string are not
 * derivable; in practice the value is always followed by an end literal, so
 * this is the intended behavior for constrained generation.
 */
std::string make_gbnf_excluding(std::vector<std::string> forbids) {
    // Escape a byte for use inside a GBNF character class [...].
    constexpr auto charclass_escape = [](unsigned char c) -> std::string {
        if (c == '\\' || c == ']' || c == '^' || c == '-') {
            std::string s = "\\";
            s.push_back((char)c);
            return s;
        }
        if (isprint(c)) {
            return std::string(1, (char)c);
        }
        // Non-printable bytes become \xNN escapes.
        char buf[16];
        snprintf(buf, 15, "\\x%02X", c);
        return std::string(buf);
    };
    // Recursively emit the expression for forbids[l, r) at the given trie depth.
    // Returns "" when every string in the range is fully matched at this depth.
    constexpr auto build_expr = [charclass_escape](auto self, const std::vector<std::string>& forbids, int l, int r, int depth) -> std::string {
        // Group the (sorted) range by the byte at `depth`: one child per distinct byte.
        std::vector<std::pair<unsigned char, std::pair<int,int>>> children;
        int i = l;
        while (i < r) {
            const std::string &s = forbids[i];
            if ((int)s.size() == depth) {
                // This forbidden string ends exactly here; it contributes no child.
                ++i;
                continue;
            }
            unsigned char c = (unsigned char)s[depth];
            int j = i;
            while (j < r && (int)forbids[j].size() > depth &&
                   (unsigned char)forbids[j][depth] == c) {
                ++j;
            }
            children.push_back({c, {i, j}});
            i = j;
        }
        std::vector<std::string> alts;
        // Any byte that is not the next byte of some forbidden string is safe.
        if (!children.empty()) {
            std::string cls;
            for (auto &ch : children) cls += charclass_escape(ch.first);
            alts.push_back(std::string("[^") + cls + "]");
        }
        // For each child byte, allow it only when followed by a safe continuation.
        for (auto &ch : children) {
            std::string childExpr = self(self, forbids, ch.second.first, ch.second.second, depth+1);
            if (!childExpr.empty()) {
                // Quote the child byte as a GBNF string literal.
                std::string quoted_ch = "\"";
                if (ch.first == '\\') quoted_ch += "\\\\";
                else if (ch.first == '"') quoted_ch += "\\\"";
                else if (isprint(ch.first)) quoted_ch.push_back(ch.first);
                else {
                    char buf[16];
                    snprintf(buf, 15, "\\x%02X", ch.first);
                    quoted_ch += buf;
                }
                quoted_ch += "\"";
                std::string branch = quoted_ch + std::string(" ") + childExpr;
                alts.push_back(branch);
            }
        }
        if (alts.empty()) return "";
        std::ostringstream oss;
        oss << "( ";
        for (size_t k = 0; k < alts.size(); ++k) {
            if (k) oss << " | ";
            oss << alts[k];
        }
        oss << " )";
        return oss.str();
    };
    // Nothing forbidden: accept anything.
    if (forbids.empty()) return "( . )*";
    // Sorting groups strings with a common prefix together for build_expr.
    sort(forbids.begin(), forbids.end());
    std::string expr = build_expr(build_expr, forbids, 0, forbids.size(), 0);
    if (expr.empty()) {
        // Degenerate case (e.g. an empty forbidden string): fall back to
        // excluding the first byte of each non-empty forbidden string.
        std::string cls;
        for (auto &s : forbids) if (!s.empty()) cls += charclass_escape((unsigned char)s[0]);
        expr = std::string("( [^") + cls + "] )";
    }
    if (forbids.size() == 1)
        return expr + "*";
    else
        return std::string("( ") + expr + " )*";
}
|
|
||||||
|
|
||||||
/**
 * Build grammar for xml-style tool call.
 * form.scope_start and form.scope_end can be empty.
 * Requires data.format for model-specific hacks.
 *
 * Produces a GBNF grammar (stored in data.grammar) that constrains generation
 * to the sequence: scope_start, one or more tool calls (tool_start, tool name,
 * tool_sep, key/value argument pairs, tool_end), scope_end. Also registers a
 * grammar trigger word on scope_start + tool_start.
 * Tools with missing/invalid "function", "name", "parameters" or "properties"
 * entries are skipped with a warning.
 */
void build_grammar_xml_tool_call(common_chat_params & data, const json & tools, const struct xml_tool_call_format & form) {
    GGML_ASSERT(!form.tool_start.empty());
    GGML_ASSERT(!form.tool_sep.empty());
    GGML_ASSERT(!form.key_start.empty());
    GGML_ASSERT(!form.val_end.empty());
    GGML_ASSERT(!form.tool_end.empty());

    // The effective key/value separator: key_val_sep optionally followed by
    // key_val_sep2 on the next line.
    std::string key_val_sep = form.key_val_sep;
    if (form.key_val_sep2) {
        key_val_sep += "\n";
        key_val_sep += *form.key_val_sep2;
    }
    GGML_ASSERT(!key_val_sep.empty());

    if (tools.is_array() && !tools.empty()) {
        data.grammar = build_grammar([&](const common_grammar_builder &builder) {
            // Raw string argument values may contain anything except a value
            // terminator (val_end / last_val_end).
            auto string_arg_val = form.last_val_end ?
                builder.add_rule("string-arg-val", make_gbnf_excluding({form.val_end, *form.last_val_end})) :
                builder.add_rule("string-arg-val", make_gbnf_excluding({form.val_end}));

            std::vector<std::string> tool_rules;
            for (const auto & tool : tools) {
                if (!tool.contains("type") || tool.at("type") != "function" || !tool.contains("function")) {
                    LOG_WRN("Skipping tool without function: %s", tool.dump(2).c_str());
                    continue;
                }
                const auto & function = tool.at("function");
                if (!function.contains("name") || !function.at("name").is_string()) {
                    LOG_WRN("Skipping invalid function (invalid name): %s", function.dump(2).c_str());
                    continue;
                }
                if (!function.contains("parameters") || !function.at("parameters").is_object()) {
                    LOG_WRN("Skipping invalid function (invalid parameters): %s", function.dump(2).c_str());
                    continue;
                }
                std::string name = function.at("name");
                auto parameters = function.at("parameters");
                builder.resolve_refs(parameters);

                // One grammar symbol per parameter, plus whether it is required.
                struct parameter_rule {
                    std::string symbol_name;
                    bool is_required;
                };
                std::vector<parameter_rule> arg_rules;
                if (!parameters.contains("properties") || !parameters.at("properties").is_object()) {
                    LOG_WRN("Skipping invalid function (invalid properties): %s", function.dump(2).c_str());
                    continue;
                } else {
                    std::vector<std::string> requiredParameters;
                    if (parameters.contains("required")) {
                        try { parameters.at("required").get_to(requiredParameters); }
                        catch (const std::runtime_error&) {
                            LOG_WRN("Invalid function required parameters, ignoring: %s", function.at("required").dump(2).c_str());
                        }
                    }
                    // Sorted for binary_search below.
                    sort_uniq(requiredParameters);
                    for (const auto & [key, value] : parameters.at("properties").items()) {
                        std::string quoted_key = key;
                        bool required = std::binary_search(requiredParameters.begin(), requiredParameters.end(), key);
                        // If the form wraps keys in double quotes (JSON-ish syntax),
                        // escape the key but strip the surrounding quotes, since the
                        // form's own delimiters already supply them.
                        if (form.key_start.back() == '"' && key_val_sep[0] == '"') {
                            quoted_key = gbnf_format_literal(key);
                            quoted_key = quoted_key.substr(1, quoted_key.size() - 2);
                        }
                        // key_start key key_val_sep value; string-typed values may be
                        // raw text and/or a JSON string depending on form.raw_argval.
                        arg_rules.push_back(parameter_rule {builder.add_rule("func-" + name + "-kv-" + key,
                            gbnf_format_literal(form.key_start) + " " +
                            gbnf_format_literal(quoted_key) + " " +
                            gbnf_format_literal(key_val_sep) + " " +
                            ((value.contains("type") && value["type"].is_string() && value["type"] == "string" && (!form.raw_argval || *form.raw_argval)) ?
                                (form.raw_argval ?
                                    string_arg_val :
                                    "( " + string_arg_val + " | " + builder.add_schema(name + "-arg-" + key, value) + " )"
                                ) :
                                builder.add_schema(name + "-arg-" + key, value)
                            )
                        ), required});
                    }
                }

                // Build the argument-sequence rules back to front. Optional
                // parameters may be skipped; required ones must appear.
                // next_arg_with_sep expects a val_end terminator before the next arg.
                auto next_arg_with_sep = builder.add_rule(name + "-last-arg-end", form.last_val_end ? gbnf_format_literal(*form.last_val_end) : gbnf_format_literal(form.val_end));
                decltype(next_arg_with_sep) next_arg = "\"\"";
                // Unsigned loop: wraps past 0 and exits via i < size().
                for (auto i = arg_rules.size() - 1; /* i >= 0 && */ i < arg_rules.size(); --i) {
                    std::string include_this_arg = arg_rules[i].symbol_name + " " + next_arg_with_sep;
                    next_arg = builder.add_rule(name + "-arg-after-" + std::to_string(i), arg_rules[i].is_required ?
                        include_this_arg : "( " + include_this_arg + " ) | " + next_arg
                    );
                    include_this_arg = gbnf_format_literal(form.val_end) + " " + include_this_arg;
                    next_arg_with_sep = builder.add_rule(name + "-arg-after-" + std::to_string(i) + "-with-sep", arg_rules[i].is_required ?
                        include_this_arg : "( " + include_this_arg + " ) | " + next_arg_with_sep
                    );
                }

                std::string quoted_name = name;
                // Same quote-stripping trick as for keys, for JSON-ish tool syntax.
                if (form.tool_start.back() == '"' && form.tool_sep[0] == '"') {
                    quoted_name = gbnf_format_literal(name);
                    quoted_name = quoted_name.substr(1, quoted_name.size() - 2);
                }
                quoted_name = gbnf_format_literal(quoted_name);
                // Kimi-K2 uses functions.{{ tool_call['function']['name'] }}:{{ loop.index }} as function name
                if (data.format == COMMON_CHAT_FORMAT_KIMI_K2) {
                    quoted_name = "\"functions.\" " + quoted_name + " \":\" [0-9]+";
                }
                tool_rules.push_back(builder.add_rule(name + "-call",
                    gbnf_format_literal(form.tool_start) + " " +
                    quoted_name + " " +
                    gbnf_format_literal(form.tool_sep) + " " +
                    next_arg
                ));
            }

            // root: scope_start ( tool-call ( tool_end tool-call )* call-end )? scope_end
            auto tool_call_once = builder.add_rule("root-tool-call-once", string_join(tool_rules, " | "));
            auto tool_call_more = builder.add_rule("root-tool-call-more", gbnf_format_literal(form.tool_end) + " " + tool_call_once);
            auto call_end = builder.add_rule("root-call-end", form.last_tool_end ? gbnf_format_literal(*form.last_tool_end) : gbnf_format_literal(form.tool_end));
            auto tool_call_multiple_with_end = builder.add_rule("root-tool-call-multiple-with-end", tool_call_once + " " + tool_call_more + "* " + call_end);
            builder.add_rule("root",
                (form.scope_start.empty() ? "" : gbnf_format_literal(form.scope_start) + " ") +
                tool_call_multiple_with_end + "?" +
                (form.scope_end.empty() ? "" : " " + gbnf_format_literal(form.scope_end))
            );
        });

        // grammar trigger for tool call
        data.grammar_triggers.push_back({ COMMON_GRAMMAR_TRIGGER_TYPE_WORD, form.scope_start + form.tool_start });
    }
}
|
|
||||||
|
|
||||||
/**
 * Parse XML-Style tool call for given xml_tool_call_format. Return false for invalid syntax and get the position untouched.
 * Throws xml_toolcall_syntax_exception if there is invalid syntax and cannot recover the original status for common_chat_msg_parser.
 * form.scope_start, form.tool_sep and form.scope_end can be empty.
 *
 * Streaming-aware: when the input ends mid tool call, a best-effort partial
 * tool call is emitted via builder.add_tool_call and
 * common_chat_msg_partial_exception is thrown to request more tokens.
 * `recovery` tracks whether the original parser state can still be restored
 * (it cannot once partial results have been emitted).
 */
inline bool parse_xml_tool_calls(common_chat_msg_parser & builder, const struct xml_tool_call_format & form) {
    GGML_ASSERT(!form.tool_start.empty());
    GGML_ASSERT(!form.key_start.empty());
    GGML_ASSERT(!form.key_val_sep.empty());
    GGML_ASSERT(!form.val_end.empty());
    GGML_ASSERT(!form.tool_end.empty());

    // Helper to choose return false (recoverable) or throw (unrecoverable).
    constexpr auto return_error = [](common_chat_msg_parser & builder, auto &start_pos, const bool &recovery) {
        LOG_DBG("Failed to parse XML-Style tool call at position: %s\n", gbnf_format_literal(builder.consume_rest().substr(0, 20)).c_str());
        if (recovery) {
            builder.move_to(start_pos);
            return false;
        } else throw xml_toolcall_syntax_exception("Tool call parsing failed with unrecoverable errors. Try using a grammar to constrain the model’s output.");
    };
    // Drop substring from needle to end from a JSON dump, producing a valid
    // truncated prefix for streaming. Returns false when the needle is absent
    // or is followed by unexpected characters (not just closing punctuation).
    constexpr auto partial_json = [](std::string &json_str, std::string_view needle = "XML_TOOL_CALL_PARTIAL_FLAG") {
        auto pos = json_str.rfind(needle);
        if (pos == std::string::npos) {
            return false;
        }
        // Only quote/brace/colon/space may follow the marker; anything else
        // means the marker text occurred inside real data.
        for (auto i = pos + needle.size(); i < json_str.size(); ++i) {
            unsigned char ch = static_cast<unsigned char>(json_str[i]);
            if (ch != '\'' && ch != '"' && ch != '}' && ch != ':' && !std::isspace(ch)) {
                return false;
            }
        }
        // Also drop the opening quote of the marker's string, if present.
        if (pos != 0 && json_str[pos - 1] == '"') {
            --pos;
        }
        json_str.resize(pos);
        return true;
    };
    // Helper to generate a partial argument JSON from the remaining input and
    // emit it as an in-progress tool call.
    constexpr auto gen_partial_json = [partial_json](auto set_partial_arg, auto &arguments, auto &builder, auto &function_name) {
        auto rest = builder.consume_rest();
        // Avoid emitting a string cut in the middle of a UTF-8 sequence.
        utf8_truncate_safe_resize(rest);
        set_partial_arg(rest, "XML_TOOL_CALL_PARTIAL_FLAG");
        auto tool_str = arguments.dump();
        if (partial_json(tool_str)) {
            if (builder.add_tool_call(function_name, "", tool_str)) {
                return;
            }
        }
        LOG_DBG("Failed to parse partial XML-Style tool call, fallback to non-partial: %s\n", tool_str.c_str());
    };
    // Helper to find a close tag (because there may be form.last_val_end or
    // form.last_tool_end variants). Picks whichever match starts earliest;
    // for the alternate end tag it must be followed by end_next/alt_end_next.
    // Returns {matched pattern size, match result}.
    constexpr auto try_find_close = [](
        common_chat_msg_parser & builder,
        const std::string & end,
        const std::optional<std::string> & alt_end,
        const std::string & end_next,
        const std::optional<std::string> & alt_end_next
    ) {
        auto saved_pos = builder.pos();
        auto tc = builder.try_find_literal(end);
        auto val_end_size = end.size();
        if (alt_end) {
            auto pos_1 = builder.pos();
            builder.move_to(saved_pos);
            auto tc2 = try_find_2_literal_splited_by_spaces(builder, *alt_end, end_next);
            if (alt_end_next) {
                builder.move_to(saved_pos);
                auto tc3 = try_find_2_literal_splited_by_spaces(builder, *alt_end, *alt_end_next);
                // Prefer whichever alternate match starts earlier.
                if (tc3 && (!tc2 || tc2->prelude.size() > tc3->prelude.size())) {
                    tc2 = tc3;
                }
            }
            if (tc2 && (!tc || tc->prelude.size() > tc2->prelude.size())) {
                // Alternate end wins: clamp the match to alt_end only (the
                // "next" literal stays in the input for the caller).
                tc = tc2;
                tc->groups[0].end = std::min(builder.input().size(), tc->groups[0].begin + alt_end->size());
                builder.move_to(tc->groups[0].end);
                val_end_size = alt_end->size();
            } else {
                builder.move_to(pos_1);
            }
        }
        return std::make_pair(val_end_size, tc);
    };
    // Helper to find a val_end or last_val_end, returns matched pattern size
    const auto try_find_val_end = [try_find_close, &builder, &form]() {
        return try_find_close(builder, form.val_end, form.last_val_end, form.tool_end, form.last_tool_end);
    };
    // Helper to find a tool_end or last_tool_end, returns matched pattern size
    const auto try_find_tool_end = [try_find_close, &builder, &form]() {
        return try_find_close(builder, form.tool_end, form.last_tool_end, form.scope_end, std::nullopt);
    };

    // recovery == true while the parser state can still be rolled back.
    bool recovery = true;
    const auto start_pos = builder.pos();
    // Match the opening scope tag (only whitespace may precede it).
    if (!all_space(form.scope_start)) {
        if (auto tc = builder.try_find_literal(form.scope_start)) {
            if (all_space(tc->prelude)) {
                if (form.scope_start.size() != tc->groups[0].end - tc->groups[0].begin)
                    throw common_chat_msg_partial_exception("Partial literal: " + gbnf_format_literal(form.scope_start));
            } else {
                builder.move_to(start_pos);
                return false;
            }
        } else return false;
    }
    // One iteration per tool call.
    while (auto tc = builder.try_find_literal(form.tool_start)) {
        if (!all_space(tc->prelude)) {
            LOG_DBG("XML-Style tool call: Expected %s, but found %s, trying to match next pattern\n",
                gbnf_format_literal(form.tool_start).c_str(),
                gbnf_format_literal(tc->prelude).c_str()
            );
            builder.move_to(tc->groups[0].begin - tc->prelude.size());
            break;
        }

        // Find tool name
        auto func_name = builder.try_find_literal(all_space(form.tool_sep) ? form.key_start : form.tool_sep);
        if (!func_name) {
            // No separator/first key: the call may have no arguments at all.
            auto [sz, tc] = try_find_tool_end();
            func_name = tc;
        }
        if (!func_name) {
            // Partial tool name not supported
            throw common_chat_msg_partial_exception("incomplete tool_call");
        }
        // If the model generate multiple tool call and the first tool call has no argument
        if (func_name->prelude.find(form.tool_end) != std::string::npos || (form.last_tool_end ? func_name->prelude.find(*form.last_tool_end) != std::string::npos : false)) {
            builder.move_to(func_name->groups[0].begin - func_name->prelude.size());
            auto [sz, tc] = try_find_tool_end();
            func_name = tc;
        }

        // Parse tool name
        builder.move_to(all_space(form.tool_sep) ? func_name->groups[0].begin : func_name->groups[0].end);
        std::string function_name = string_strip(func_name->prelude);
        // Kimi-K2 uses functions.{{ tool_call['function']['name'] }}:{{ loop.index }} as function name
        if (builder.syntax().format == COMMON_CHAT_FORMAT_KIMI_K2) {
            if (string_starts_with(function_name, "functions.")) {
                static const std::regex re(":\\d+$");
                if (std::regex_search(function_name, re)) {
                    // Strip the "functions." prefix and the ":<index>" suffix.
                    function_name = function_name.substr(10, function_name.rfind(":") - 10);
                }
            }
        }

        // Argument JSON
        json arguments = json::object();

        // Helper to generate a partial argument JSON
        const auto gen_partial_args = [&](auto set_partial_arg) {
            gen_partial_json(set_partial_arg, arguments, builder, function_name);
        };

        // Parse all arg_key/arg_value pairs
        while (auto tc = builder.try_find_literal(form.key_start)) {
            if (!all_space(tc->prelude)) {
                LOG_DBG("XML-Style tool call: Expected %s, but found %s, trying to match next pattern\n",
                    gbnf_format_literal(form.key_start).c_str(),
                    gbnf_format_literal(tc->prelude).c_str()
                );
                builder.move_to(tc->groups[0].begin - tc->prelude.size());
                break;
            }
            if (tc->groups[0].end - tc->groups[0].begin != form.key_start.size()) {
                // Input ended inside key_start: emit what we have (without the
                // closing brace so it reads as an unfinished JSON object).
                auto tool_call_arg = arguments.dump();
                if (tool_call_arg.size() != 0 && tool_call_arg[tool_call_arg.size() - 1] == '}') {
                    tool_call_arg.resize(tool_call_arg.size() - 1);
                }
                builder.add_tool_call(function_name, "", tool_call_arg);
                throw common_chat_msg_partial_exception("Partial literal: " + gbnf_format_literal(form.key_start));
            }

            // Parse arg_key
            auto key_res = builder.try_find_literal(form.key_val_sep);
            if (!key_res) {
                gen_partial_args([&](auto &rest, auto &needle) {arguments[rest + needle] = "";});
                throw common_chat_msg_partial_exception("Expected " + gbnf_format_literal(form.key_val_sep) + " after " + gbnf_format_literal(form.key_start));
            }
            if (key_res->groups[0].end - key_res->groups[0].begin != form.key_val_sep.size()) {
                gen_partial_args([&](auto &, auto &needle) {arguments[key_res->prelude + needle] = "";});
                throw common_chat_msg_partial_exception("Partial literal: " + gbnf_format_literal(form.key_val_sep));
            }
            auto &key = key_res->prelude;
            // A key has been committed; rolling back is no longer safe.
            recovery = false;

            // Parse arg_value
            if (form.key_val_sep2) {
                if (auto tc = builder.try_find_literal(*form.key_val_sep2)) {
                    if (!all_space(tc->prelude)) {
                        LOG_DBG("Failed to parse XML-Style tool call: Unexcepted %s between %s and %s\n",
                            gbnf_format_literal(tc->prelude).c_str(),
                            gbnf_format_literal(form.key_val_sep).c_str(),
                            gbnf_format_literal(*form.key_val_sep2).c_str()
                        );
                        return return_error(builder, start_pos, false);
                    }
                    if (tc->groups[0].end - tc->groups[0].begin != form.key_val_sep2->size()) {
                        gen_partial_args([&](auto &, auto &needle) {arguments[key] = needle;});
                        throw common_chat_msg_partial_exception("Partial literal: " + gbnf_format_literal(*form.key_val_sep2));
                    }
                } else {
                    gen_partial_args([&](auto &, auto &needle) {arguments[key] = needle;});
                    throw common_chat_msg_partial_exception("Expected " + gbnf_format_literal(*form.key_val_sep2) + " after " + gbnf_format_literal(form.key_val_sep));
                }
            }
            auto val_start = builder.pos();

            // Test if arg_val is a partial JSON
            std::optional<common_json> value_json = std::nullopt;
            if (!form.raw_argval || !*form.raw_argval) {
                try { value_json = builder.try_consume_json(); }
                catch (const std::runtime_error&) { builder.move_to(val_start); }
                // TODO: Delete this when json_partial adds top-level support for null/true/false
                if (builder.pos() == val_start) {
                    const static std::regex number_regex(R"([0-9-][0-9]*(\.\d*)?([eE][+-]?\d*)?)");
                    builder.consume_spaces();
                    std::string_view sv = utf8_truncate_safe_view(builder.input());
                    sv.remove_prefix(builder.pos());
                    // "a" is a sentinel that never prefix-matches null/true/false.
                    std::string rest = "a";
                    if (sv.size() < 6) rest = sv;
                    if (string_starts_with("null", rest) || string_starts_with("true", rest) || string_starts_with("false", rest) || std::regex_match(sv.begin(), sv.end(), number_regex)) {
                        // Placeholder marking "looks like a JSON scalar so far".
                        value_json = {123, {"123", "123"}};
                        builder.consume_rest();
                    } else {
                        builder.move_to(val_start);
                    }
                }
            }

            // If it is a JSON and followed by </arg_value>, parse as json
            // cannot support streaming because it may be a plain text starting with JSON
            if (value_json) {
                auto json_end = builder.pos();
                builder.consume_spaces();
                if (builder.pos() == builder.input().size()) {
                    // Input ends right after the JSON candidate: emit partial.
                    if (form.raw_argval && !*form.raw_argval && (value_json->json.is_string() || value_json->json.is_object() || value_json->json.is_array())) {
                        arguments[key] = value_json->json;
                        auto json_str = arguments.dump();
                        if (!value_json->healing_marker.json_dump_marker.empty()) {
                            GGML_ASSERT(std::string::npos != json_str.rfind(value_json->healing_marker.json_dump_marker));
                            json_str.resize(json_str.rfind(value_json->healing_marker.json_dump_marker));
                        } else {
                            GGML_ASSERT(json_str.back() == '}');
                            json_str.resize(json_str.size() - 1);
                        }
                        builder.add_tool_call(function_name, "", json_str);
                    } else {
                        gen_partial_args([&](auto &, auto &needle) {arguments[key] = needle;});
                    }
                    LOG_DBG("Possible JSON arg_value: %s\n", value_json->json.dump().c_str());
                    throw common_chat_msg_partial_exception("JSON arg_value detected. Waiting for more tokens for validations.");
                }
                builder.move_to(json_end);
                auto [val_end_size, tc] = try_find_val_end();
                // Accept the JSON value only when a complete, immediately
                // following val_end confirms it; otherwise fall back to text.
                if (tc && all_space(tc->prelude) && value_json->healing_marker.marker.empty()) {
                    if (tc->groups[0].end - tc->groups[0].begin != val_end_size) {
                        gen_partial_args([&](auto &, auto &needle) {arguments[key] = needle;});
                        LOG_DBG("Possible terminated JSON arg_value: %s\n", value_json->json.dump().c_str());
                        throw common_chat_msg_partial_exception("Partial literal: " + gbnf_format_literal(form.val_end) + (form.last_val_end ? gbnf_format_literal(*form.last_val_end) : ""));
                    } else arguments[key] = value_json->json;
                } else builder.move_to(val_start);
            }

            // If not, parse as plain text
            if (val_start == builder.pos()) {
                if (auto [val_end_size, value_plain] = try_find_val_end(); value_plain) {
                    auto &value_str = value_plain->prelude;
                    if (form.trim_raw_argval) value_str = string_strip(value_str);
                    if (value_plain->groups[0].end - value_plain->groups[0].begin != val_end_size) {
                        gen_partial_args([&](auto &, auto &needle) {arguments[key] = value_str + needle;});
                        throw common_chat_msg_partial_exception(
                            "Expected " + gbnf_format_literal(form.val_end) +
                            " after " + gbnf_format_literal(form.key_val_sep) +
                            (form.key_val_sep2 ? " " + gbnf_format_literal(*form.key_val_sep2) : "")
                        );
                    }
                    arguments[key] = value_str;
                } else {
                    // No terminator yet: the whole remaining input is a partial value.
                    if (form.trim_raw_argval) {
                        gen_partial_args([&](auto &rest, auto &needle) {arguments[key] = string_strip(rest) + needle;});
                    } else {
                        gen_partial_args([&](auto &rest, auto &needle) {arguments[key] = rest + needle;});
                    }
                    throw common_chat_msg_partial_exception(
                        "Expected " + gbnf_format_literal(form.val_end) +
                        " after " + gbnf_format_literal(form.key_val_sep) +
                        (form.key_val_sep2 ? " " + gbnf_format_literal(*form.key_val_sep2) : "")
                    );
                }
            }
        }

        // Consume closing tag
        if (auto [tool_end_size, tc] = try_find_tool_end(); tc) {
            if (!all_space(tc->prelude)) {
                LOG_DBG("Failed to parse XML-Style tool call: Expected %s, but found %s\n",
                    gbnf_format_literal(form.tool_end).c_str(),
                    gbnf_format_literal(tc->prelude).c_str()
                );
                return return_error(builder, start_pos, recovery);
            }
            if (tc->groups[0].end - tc->groups[0].begin == tool_end_size) {
                // Add the parsed tool call
                if (!builder.add_tool_call(function_name, "", arguments.dump())) {
                    throw common_chat_msg_partial_exception("Failed to add XML-Style tool call");
                }
                recovery = false;
                continue;
            }
        }

        // Closing tag incomplete: emit the arguments as an unfinished object.
        auto tool_call_arg = arguments.dump();
        if (tool_call_arg.size() != 0 && tool_call_arg[tool_call_arg.size() - 1] == '}') {
            tool_call_arg.resize(tool_call_arg.size() - 1);
        }
        builder.add_tool_call(function_name, "", tool_call_arg);
        throw common_chat_msg_partial_exception("Expected " + gbnf_format_literal(form.tool_end) + " after " + gbnf_format_literal(form.val_end));
    }
    // Match the closing scope tag.
    if (auto tc = builder.try_find_literal(form.scope_end)) {
        if (!all_space(tc->prelude)) {
            LOG_DBG("Failed to parse XML-Style tool call: Expected %s, but found %s\n",
                gbnf_format_literal(form.scope_end).c_str(),
                gbnf_format_literal(tc->prelude).c_str()
            );
            return return_error(builder, start_pos, recovery);
        }
    } else {
        if (all_space(form.scope_end)) return true;
        builder.consume_spaces();
        if (builder.pos() == builder.input().size())
            throw common_chat_msg_partial_exception("incomplete tool calls");
        LOG_DBG("Failed to parse XML-Style tool call: Expected %s, but found %s\n",
            gbnf_format_literal(form.scope_end).c_str(),
            gbnf_format_literal(builder.consume_rest()).c_str()
        );
        return return_error(builder, start_pos, recovery);
    }

    return true;
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Parse XML-Style tool call for given xml_tool_call_format. Return false for invalid syntax and get the position untouched.
|
|
||||||
* May cause std::runtime_error if there is invalid syntax because partial valid tool call is already sent out to client.
|
|
||||||
* form.scope_start, form.tool_sep and form.scope_end can be empty.
|
|
||||||
*/
|
|
||||||
bool common_chat_msg_parser::try_consume_xml_tool_calls(const struct xml_tool_call_format & form) {
|
|
||||||
auto pos = pos_;
|
|
||||||
auto tsize = result_.tool_calls.size();
|
|
||||||
try { return parse_xml_tool_calls(*this, form); }
|
|
||||||
catch (const xml_toolcall_syntax_exception&) {}
|
|
||||||
move_to(pos);
|
|
||||||
result_.tool_calls.resize(tsize);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Parse content uses reasoning and XML-Style tool call
|
|
||||||
* TODO: Note that form.allow_toolcall_in_think is not tested yet. If anyone confirms it works, this comment can be removed.
|
|
||||||
*/
|
|
||||||
inline void parse_msg_with_xml_tool_calls(common_chat_msg_parser & builder, const struct xml_tool_call_format & form, const std::string & start_think = "<think>", const std::string & end_think = "</think>") {
|
|
||||||
constexpr auto rstrip = [](std::string &s) {
|
|
||||||
s.resize(std::distance(s.begin(), std::find_if(s.rbegin(), s.rend(), [](unsigned char ch) { return !std::isspace(ch); }).base()));
|
|
||||||
};
|
|
||||||
// Erase substring from l to r, along with additional spaces nearby
|
|
||||||
constexpr auto erase_spaces = [](auto &str, size_t l, size_t r) {
|
|
||||||
while (/* l > -1 && */ --l < str.size() && std::isspace(static_cast<unsigned char>(str[l])));
|
|
||||||
++l;
|
|
||||||
while (++r < str.size() && std::isspace(static_cast<unsigned char>(str[r])));
|
|
||||||
if (l < r) str[l] = '\n';
|
|
||||||
if (l + 1 < r) str[l + 1] = '\n';
|
|
||||||
if (l != 0) l += 2;
|
|
||||||
str.erase(l, r - l);
|
|
||||||
return l;
|
|
||||||
};
|
|
||||||
constexpr auto trim_suffix = [](std::string &content, std::initializer_list<std::string_view> list) {
|
|
||||||
auto best_match = content.size();
|
|
||||||
for (auto pattern: list) {
|
|
||||||
if (pattern.size() == 0) continue;
|
|
||||||
for (auto match_idx = content.size() - std::min(pattern.size(), content.size()); content.size() > match_idx; match_idx++) {
|
|
||||||
auto match_len = content.size() - match_idx;
|
|
||||||
if (content.compare(match_idx, match_len, pattern.data(), match_len) == 0 && best_match > match_idx) {
|
|
||||||
best_match = match_idx;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (content.size() > best_match) {
|
|
||||||
content.erase(best_match);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
const auto trim_potential_partial_word = [&start_think, &end_think, &form, trim_suffix](std::string &content) {
|
|
||||||
return trim_suffix(content, {
|
|
||||||
start_think, end_think, form.scope_start, form.tool_start, form.tool_sep, form.key_start,
|
|
||||||
form.key_val_sep, form.key_val_sep2 ? form.key_val_sep2->c_str() : "",
|
|
||||||
form.val_end, form.last_val_end ? form.last_val_end->c_str() : "",
|
|
||||||
form.tool_end, form.last_tool_end ? form.last_tool_end->c_str() : "",
|
|
||||||
form.scope_end
|
|
||||||
});
|
|
||||||
};
|
|
||||||
|
|
||||||
|
|
||||||
// Trim leading spaces without affecting keyword matching
|
|
||||||
static const common_regex spaces_regex("\\s*");
|
|
||||||
{
|
|
||||||
auto tc = builder.consume_regex(spaces_regex);
|
|
||||||
auto spaces = builder.str(tc.groups[0]);
|
|
||||||
auto s1 = spaces.size();
|
|
||||||
trim_potential_partial_word(spaces);
|
|
||||||
auto s2 = spaces.size();
|
|
||||||
builder.move_to(builder.pos() - (s1 - s2));
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parse content
|
|
||||||
bool reasoning_unclosed = builder.syntax().thinking_forced_open;
|
|
||||||
std::string unclosed_reasoning_content("");
|
|
||||||
for (;;) {
|
|
||||||
auto tc = try_find_2_literal_splited_by_spaces(builder, form.scope_start, form.tool_start);
|
|
||||||
std::string content;
|
|
||||||
std::string tool_call_start;
|
|
||||||
|
|
||||||
if (tc) {
|
|
||||||
content = std::move(tc->prelude);
|
|
||||||
tool_call_start = builder.str(tc->groups[0]);
|
|
||||||
LOG_DBG("Matched tool start: %s\n", gbnf_format_literal(tool_call_start).c_str());
|
|
||||||
} else {
|
|
||||||
content = builder.consume_rest();
|
|
||||||
utf8_truncate_safe_resize(content);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle unclosed think block
|
|
||||||
if (reasoning_unclosed) {
|
|
||||||
if (auto pos = content.find(end_think); pos == std::string::npos && builder.pos() != builder.input().size()) {
|
|
||||||
unclosed_reasoning_content += content;
|
|
||||||
if (!(form.allow_toolcall_in_think && tc)) {
|
|
||||||
unclosed_reasoning_content += tool_call_start;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
reasoning_unclosed = false;
|
|
||||||
std::string reasoning_content;
|
|
||||||
if (pos == std::string::npos) {
|
|
||||||
reasoning_content = std::move(content);
|
|
||||||
} else {
|
|
||||||
reasoning_content = content.substr(0, pos);
|
|
||||||
content.erase(0, pos + end_think.size());
|
|
||||||
}
|
|
||||||
if (builder.pos() == builder.input().size() && all_space(content)) {
|
|
||||||
rstrip(reasoning_content);
|
|
||||||
trim_potential_partial_word(reasoning_content);
|
|
||||||
rstrip(reasoning_content);
|
|
||||||
if (reasoning_content.empty()) {
|
|
||||||
rstrip(unclosed_reasoning_content);
|
|
||||||
trim_potential_partial_word(unclosed_reasoning_content);
|
|
||||||
rstrip(unclosed_reasoning_content);
|
|
||||||
if (unclosed_reasoning_content.empty()) continue;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (builder.syntax().reasoning_format == COMMON_REASONING_FORMAT_NONE || builder.syntax().reasoning_in_content) {
|
|
||||||
builder.add_content(start_think);
|
|
||||||
builder.add_content(unclosed_reasoning_content);
|
|
||||||
builder.add_content(reasoning_content);
|
|
||||||
if (builder.pos() != builder.input().size() || !all_space(content))
|
|
||||||
builder.add_content(end_think);
|
|
||||||
} else {
|
|
||||||
builder.add_reasoning_content(unclosed_reasoning_content);
|
|
||||||
builder.add_reasoning_content(reasoning_content);
|
|
||||||
}
|
|
||||||
unclosed_reasoning_content.clear();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle multiple think block
|
|
||||||
bool toolcall_in_think = false;
|
|
||||||
for (auto think_start = content.find(start_think); think_start != std::string::npos; think_start = content.find(start_think, think_start)) {
|
|
||||||
if (auto think_end = content.find(end_think, think_start + start_think.size()); think_end != std::string::npos) {
|
|
||||||
if (builder.syntax().reasoning_format != COMMON_REASONING_FORMAT_NONE && !builder.syntax().reasoning_in_content) {
|
|
||||||
auto reasoning_content = content.substr(think_start + start_think.size(), think_end - think_start - start_think.size());
|
|
||||||
builder.add_reasoning_content(reasoning_content);
|
|
||||||
think_start = erase_spaces(content, think_start, think_end + end_think.size() - 1);
|
|
||||||
} else {
|
|
||||||
think_start = think_end + end_think.size() - 1;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// This <tool_call> start is in thinking block, skip this tool call
|
|
||||||
// This <tool_call> start is in thinking block
|
|
||||||
if (form.allow_toolcall_in_think) {
|
|
||||||
unclosed_reasoning_content = content.substr(think_start + start_think.size());
|
|
||||||
} else {
|
|
||||||
unclosed_reasoning_content = content.substr(think_start + start_think.size()) + tool_call_start;
|
|
||||||
}
|
|
||||||
reasoning_unclosed = true;
|
|
||||||
content.resize(think_start);
|
|
||||||
toolcall_in_think = true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (builder.syntax().reasoning_format != COMMON_REASONING_FORMAT_NONE && !builder.syntax().reasoning_in_content) {
|
|
||||||
rstrip(content);
|
|
||||||
// Handle unclosed </think> token from content: delete all </think> token
|
|
||||||
if (auto pos = content.rfind(end_think); pos != std::string::npos) {
|
|
||||||
while (pos != std::string::npos) {
|
|
||||||
pos = erase_spaces(content, pos, pos + end_think.size() - 1);
|
|
||||||
pos = content.rfind(end_think, pos);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// Strip if needed
|
|
||||||
if (content.size() > 0 && std::isspace(static_cast<unsigned char>(content[0]))) {
|
|
||||||
content = string_strip(content);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// remove potential partial suffix
|
|
||||||
if (builder.pos() == builder.input().size()) {
|
|
||||||
if (unclosed_reasoning_content.empty()) {
|
|
||||||
rstrip(content);
|
|
||||||
trim_potential_partial_word(content);
|
|
||||||
rstrip(content);
|
|
||||||
} else {
|
|
||||||
rstrip(unclosed_reasoning_content);
|
|
||||||
trim_potential_partial_word(unclosed_reasoning_content);
|
|
||||||
rstrip(unclosed_reasoning_content);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// consume unclosed_reasoning_content if allow_toolcall_in_think is set
|
|
||||||
if (form.allow_toolcall_in_think && !unclosed_reasoning_content.empty()) {
|
|
||||||
if (builder.syntax().reasoning_format != COMMON_REASONING_FORMAT_NONE && !builder.syntax().reasoning_in_content) {
|
|
||||||
builder.add_reasoning_content(unclosed_reasoning_content);
|
|
||||||
} else {
|
|
||||||
if (content.empty()) {
|
|
||||||
content = start_think + unclosed_reasoning_content;
|
|
||||||
} else {
|
|
||||||
content += "\n\n" + start_think;
|
|
||||||
content += unclosed_reasoning_content;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
unclosed_reasoning_content.clear();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add content
|
|
||||||
if (!content.empty()) {
|
|
||||||
// If there are multiple content blocks
|
|
||||||
if (builder.syntax().reasoning_format != COMMON_REASONING_FORMAT_NONE && !builder.syntax().reasoning_in_content && builder.result().content.size() != 0) {
|
|
||||||
builder.add_content("\n\n");
|
|
||||||
}
|
|
||||||
builder.add_content(content);
|
|
||||||
}
|
|
||||||
|
|
||||||
// This <tool_call> start is in thinking block and toolcall_in_think not set, skip this tool call
|
|
||||||
if (toolcall_in_think && !form.allow_toolcall_in_think) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
// There is no tool call and all content is parsed
|
|
||||||
if (!tc) {
|
|
||||||
GGML_ASSERT(builder.pos() == builder.input().size());
|
|
||||||
GGML_ASSERT(unclosed_reasoning_content.empty());
|
|
||||||
if (!form.allow_toolcall_in_think) GGML_ASSERT(!reasoning_unclosed);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
builder.move_to(tc->groups[0].begin);
|
|
||||||
if (builder.try_consume_xml_tool_calls(form)) {
|
|
||||||
auto end_of_tool = builder.pos();
|
|
||||||
builder.consume_spaces();
|
|
||||||
if (builder.pos() != builder.input().size()) {
|
|
||||||
builder.move_to(end_of_tool);
|
|
||||||
if (!builder.result().content.empty()) {
|
|
||||||
builder.add_content("\n\n");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
static const common_regex next_char_regex(".");
|
|
||||||
auto c = builder.str(builder.consume_regex(next_char_regex).groups[0]);
|
|
||||||
rstrip(c);
|
|
||||||
builder.add_content(c);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Parse content uses reasoning and XML-Style tool call
|
|
||||||
*/
|
|
||||||
void common_chat_msg_parser::consume_reasoning_with_xml_tool_calls(const struct xml_tool_call_format & form, const std::string & start_think, const std::string & end_think) {
|
|
||||||
parse_msg_with_xml_tool_calls(*this, form, start_think, end_think);
|
|
||||||
}
|
|
||||||
@@ -1,45 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
#include "chat.h"
|
|
||||||
|
|
||||||
#include <nlohmann/json.hpp>
|
|
||||||
|
|
||||||
#include <optional>
|
|
||||||
#include <string>
|
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
|
|
||||||
// Sample config:
|
|
||||||
// MiniMax-M2 (left): <minimax:tool_call>\n<invoke name="tool-name">\n<parameter name="key">value</parameter>\n...</invoke>\n...</minimax:tool_call>
|
|
||||||
// GLM 4.5 (right): <tool_call>function_name\n<arg_key>key</arg_key>\n<arg_value>value</arg_value>\n</tool_call>
|
|
||||||
struct xml_tool_call_format {
|
|
||||||
std::string scope_start; // <minimax:tool_call>\n // \n // can be empty
|
|
||||||
std::string tool_start; // <invoke name=\" // <tool_call>
|
|
||||||
std::string tool_sep; // \">\n // \n // can be empty only for parse_xml_tool_calls
|
|
||||||
std::string key_start; // <parameter name=\" // <arg_key>
|
|
||||||
std::string key_val_sep; // \"> // </arg_key>\n<arg_value>
|
|
||||||
std::string val_end; // </parameter>\n // </arg_value>\n
|
|
||||||
std::string tool_end; // </invoke>\n // </tool_call>\n
|
|
||||||
std::string scope_end; // </minimax:tool_call> // // can be empty
|
|
||||||
// Set this if there can be dynamic spaces inside key_val_sep.
|
|
||||||
// e.g. key_val_sep=</arg_key> key_val_sep2=<arg_value> for GLM4.5
|
|
||||||
std::optional<std::string> key_val_sep2 = std::nullopt;
|
|
||||||
// Set true if argval should only be raw string. e.g. Hello "world" hi
|
|
||||||
// Set false if argval should only be json string. e.g. "Hello \"world\" hi"
|
|
||||||
// Defaults to std::nullopt, both will be allowed.
|
|
||||||
std::optional<bool> raw_argval = std::nullopt;
|
|
||||||
std::optional<std::string> last_val_end = std::nullopt;
|
|
||||||
std::optional<std::string> last_tool_end = std::nullopt;
|
|
||||||
bool trim_raw_argval = false;
|
|
||||||
bool allow_toolcall_in_think = false;
|
|
||||||
};
|
|
||||||
|
|
||||||
// make a GBNF that accept any strings except those containing any of the forbidden strings.
|
|
||||||
std::string make_gbnf_excluding(std::vector<std::string> forbids);
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Build grammar for xml-style tool call
|
|
||||||
* form.scope_start and form.scope_end can be empty.
|
|
||||||
* Requires data.format for model-specific hacks.
|
|
||||||
*/
|
|
||||||
void build_grammar_xml_tool_call(common_chat_params & data, const nlohmann::ordered_json & tools, const struct xml_tool_call_format & form);
|
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -1,133 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
#include "chat.h"
|
|
||||||
#include "chat-parser-xml-toolcall.h"
|
|
||||||
#include "json-partial.h"
|
|
||||||
#include "regex-partial.h"
|
|
||||||
|
|
||||||
#include <nlohmann/json.hpp>
|
|
||||||
|
|
||||||
#include <optional>
|
|
||||||
#include <string>
|
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
class common_chat_msg_partial_exception : public std::runtime_error {
|
|
||||||
public:
|
|
||||||
common_chat_msg_partial_exception(const std::string & message) : std::runtime_error(message) {}
|
|
||||||
};
|
|
||||||
|
|
||||||
class common_chat_msg_parser {
|
|
||||||
std::string input_;
|
|
||||||
bool is_partial_;
|
|
||||||
common_chat_syntax syntax_;
|
|
||||||
std::string healing_marker_;
|
|
||||||
|
|
||||||
size_t pos_ = 0;
|
|
||||||
common_chat_msg result_;
|
|
||||||
|
|
||||||
public:
|
|
||||||
common_chat_msg_parser(const std::string & input, bool is_partial, const common_chat_syntax & syntax);
|
|
||||||
const std::string & input() const { return input_; }
|
|
||||||
size_t pos() const { return pos_; }
|
|
||||||
const std::string & healing_marker() const { return healing_marker_; }
|
|
||||||
const bool & is_partial() const { return is_partial_; }
|
|
||||||
const common_chat_msg & result() const { return result_; }
|
|
||||||
const common_chat_syntax & syntax() const { return syntax_; }
|
|
||||||
|
|
||||||
void move_to(size_t pos) {
|
|
||||||
if (pos > input_.size()) {
|
|
||||||
throw std::runtime_error("Invalid position!");
|
|
||||||
}
|
|
||||||
pos_ = pos;
|
|
||||||
}
|
|
||||||
void move_back(size_t n) {
|
|
||||||
if (pos_ < n) {
|
|
||||||
throw std::runtime_error("Can't move back that far!");
|
|
||||||
}
|
|
||||||
pos_ -= n;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get the substring of the input at the given range
|
|
||||||
std::string str(const common_string_range & rng) const;
|
|
||||||
|
|
||||||
// Appends to the result.content field
|
|
||||||
void add_content(const std::string & content);
|
|
||||||
|
|
||||||
// Appends to the result.reasoning_content field
|
|
||||||
void add_reasoning_content(const std::string & reasoning_content);
|
|
||||||
|
|
||||||
// Adds a tool call to the result. If the tool call is too incomplete (e.g. name empty), it won't add anything.
|
|
||||||
bool add_tool_call(const std::string & name, const std::string & id, const std::string & arguments);
|
|
||||||
|
|
||||||
// Adds a tool call using the "name", "id" and "arguments" fields of the json object
|
|
||||||
bool add_tool_call(const nlohmann::ordered_json & tool_call);
|
|
||||||
|
|
||||||
// Adds an array of tool calls using their "name", "id" and "arguments" fields.
|
|
||||||
bool add_tool_calls(const nlohmann::ordered_json & arr);
|
|
||||||
|
|
||||||
// Adds a tool call using the short form: { "tool_name": { "arg1": val, "arg2": val } }
|
|
||||||
bool add_tool_call_short_form(const nlohmann::ordered_json & tool_call);
|
|
||||||
|
|
||||||
void finish();
|
|
||||||
|
|
||||||
bool consume_spaces();
|
|
||||||
|
|
||||||
void consume_literal(const std::string & literal);
|
|
||||||
|
|
||||||
bool try_parse_reasoning(const std::string & start_think, const std::string & end_think);
|
|
||||||
|
|
||||||
std::string consume_rest();
|
|
||||||
|
|
||||||
struct find_regex_result {
|
|
||||||
std::string prelude;
|
|
||||||
std::vector<common_string_range> groups;
|
|
||||||
};
|
|
||||||
|
|
||||||
std::optional<find_regex_result> try_find_regex(const common_regex & regex, size_t from = std::string::npos, bool add_prelude_to_content = true);
|
|
||||||
|
|
||||||
bool try_consume_literal(const std::string & literal);
|
|
||||||
|
|
||||||
std::optional<find_regex_result> try_find_literal(const std::string & literal);
|
|
||||||
|
|
||||||
find_regex_result consume_regex(const common_regex & regex);
|
|
||||||
|
|
||||||
std::optional<find_regex_result> try_consume_regex(const common_regex & regex);
|
|
||||||
|
|
||||||
std::optional<common_json> try_consume_json();
|
|
||||||
common_json consume_json();
|
|
||||||
|
|
||||||
struct consume_json_result {
|
|
||||||
nlohmann::ordered_json value;
|
|
||||||
bool is_partial;
|
|
||||||
};
|
|
||||||
|
|
||||||
/*
|
|
||||||
Consume (possibly partial) json and converts specific subtrees to (possibly truncated) JSON strings.
|
|
||||||
|
|
||||||
By default, object keys can't be truncated, nor can string values (their corresponding key is removed,
|
|
||||||
e.g. `{"foo": "bar", "baz": "b` -> `{"foo": "bar"}`
|
|
||||||
|
|
||||||
But one can allow subpaths to be kept truncated, and possibly json-dumped to truncated json strings
|
|
||||||
- with `content_paths={{"foo"}}` -> `{"foo": "b` -> {"foo": "b"}`
|
|
||||||
- with `args_paths={{"foo"}}` -> `{"foo": {"b` -> `{"foo": "{b"}`
|
|
||||||
*/
|
|
||||||
consume_json_result consume_json_with_dumped_args(
|
|
||||||
const std::vector<std::vector<std::string>> & args_paths = {},
|
|
||||||
const std::vector<std::vector<std::string>> & content_paths = {}
|
|
||||||
);
|
|
||||||
std::optional<consume_json_result> try_consume_json_with_dumped_args(
|
|
||||||
const std::vector<std::vector<std::string>> & args_paths = {},
|
|
||||||
const std::vector<std::vector<std::string>> & content_paths = {}
|
|
||||||
);
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Parse XML-Style tool call for given xml_tool_call_format. Return false for invalid syntax and get the position untouched.
|
|
||||||
* form.scope_start, form.tool_sep and form.scope_end can be empty.
|
|
||||||
*/
|
|
||||||
bool try_consume_xml_tool_calls(const struct xml_tool_call_format & form);
|
|
||||||
|
|
||||||
// Parse content uses reasoning and XML-Style tool call
|
|
||||||
void consume_reasoning_with_xml_tool_calls(const struct xml_tool_call_format & form, const std::string & start_think = "<think>", const std::string & end_think = "</think>");
|
|
||||||
|
|
||||||
void clear_tools();
|
|
||||||
};
|
|
||||||
@@ -1,124 +0,0 @@
|
|||||||
#include "chat-peg-parser.h"
|
|
||||||
|
|
||||||
#include <nlohmann/json.hpp>
|
|
||||||
|
|
||||||
using json = nlohmann::json;
|
|
||||||
|
|
||||||
static std::string_view trim_trailing_space(std::string_view sv, int max = -1) {
|
|
||||||
int count = 0;
|
|
||||||
while (!sv.empty() && std::isspace(static_cast<unsigned char>(sv.back()))) {
|
|
||||||
if (max != -1 && count <= max) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
sv.remove_suffix(1);
|
|
||||||
count++;
|
|
||||||
}
|
|
||||||
return sv;
|
|
||||||
}
|
|
||||||
|
|
||||||
void common_chat_peg_mapper::from_ast(const common_peg_ast_arena & arena, const common_peg_parse_result & result) {
|
|
||||||
arena.visit(result, [this](const common_peg_ast_node & node) {
|
|
||||||
map(node);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
void common_chat_peg_mapper::map(const common_peg_ast_node & node) {
|
|
||||||
bool is_reasoning = node.tag == common_chat_peg_builder::REASONING;
|
|
||||||
bool is_content = node.tag == common_chat_peg_builder::CONTENT;
|
|
||||||
|
|
||||||
if (is_reasoning) {
|
|
||||||
result.reasoning_content = std::string(trim_trailing_space(node.text));
|
|
||||||
}
|
|
||||||
|
|
||||||
if (is_content) {
|
|
||||||
result.content = std::string(trim_trailing_space(node.text));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void common_chat_peg_native_mapper::map(const common_peg_ast_node & node) {
|
|
||||||
common_chat_peg_mapper::map(node);
|
|
||||||
|
|
||||||
bool is_tool_open = node.tag == common_chat_peg_native_builder::TOOL_OPEN;
|
|
||||||
bool is_tool_name = node.tag == common_chat_peg_native_builder::TOOL_NAME;
|
|
||||||
bool is_tool_id = node.tag == common_chat_peg_native_builder::TOOL_ID;
|
|
||||||
bool is_tool_args = node.tag == common_chat_peg_native_builder::TOOL_ARGS;
|
|
||||||
|
|
||||||
if (is_tool_open) {
|
|
||||||
result.tool_calls.emplace_back();
|
|
||||||
current_tool = &result.tool_calls.back();
|
|
||||||
}
|
|
||||||
|
|
||||||
if (is_tool_id && current_tool) {
|
|
||||||
current_tool->id = std::string(trim_trailing_space(node.text));
|
|
||||||
}
|
|
||||||
|
|
||||||
if (is_tool_name && current_tool) {
|
|
||||||
current_tool->name = std::string(trim_trailing_space(node.text));
|
|
||||||
}
|
|
||||||
|
|
||||||
if (is_tool_args && current_tool) {
|
|
||||||
current_tool->arguments = std::string(trim_trailing_space(node.text));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void common_chat_peg_constructed_mapper::map(const common_peg_ast_node & node) {
|
|
||||||
common_chat_peg_mapper::map(node);
|
|
||||||
|
|
||||||
bool is_tool_open = node.tag == common_chat_peg_constructed_builder::TOOL_OPEN;
|
|
||||||
bool is_tool_name = node.tag == common_chat_peg_constructed_builder::TOOL_NAME;
|
|
||||||
bool is_tool_close = node.tag == common_chat_peg_constructed_builder::TOOL_CLOSE;
|
|
||||||
bool is_arg_open = node.tag == common_chat_peg_constructed_builder::TOOL_ARG_OPEN;
|
|
||||||
bool is_arg_close = node.tag == common_chat_peg_constructed_builder::TOOL_ARG_CLOSE;
|
|
||||||
bool is_arg_name = node.tag == common_chat_peg_constructed_builder::TOOL_ARG_NAME;
|
|
||||||
bool is_arg_string = node.tag == common_chat_peg_constructed_builder::TOOL_ARG_STRING_VALUE;
|
|
||||||
bool is_arg_json = node.tag == common_chat_peg_constructed_builder::TOOL_ARG_JSON_VALUE;
|
|
||||||
|
|
||||||
if (is_tool_open) {
|
|
||||||
result.tool_calls.emplace_back();
|
|
||||||
current_tool = &result.tool_calls.back();
|
|
||||||
arg_count = 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (is_tool_name) {
|
|
||||||
current_tool->name = std::string(node.text);
|
|
||||||
current_tool->arguments = "{";
|
|
||||||
}
|
|
||||||
|
|
||||||
if (is_arg_open) {
|
|
||||||
needs_closing_quote = false;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (is_arg_name && current_tool) {
|
|
||||||
if (arg_count > 0) {
|
|
||||||
current_tool->arguments += ",";
|
|
||||||
}
|
|
||||||
current_tool->arguments += json(trim_trailing_space(node.text)).dump() + ":";
|
|
||||||
++arg_count;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (is_arg_string && current_tool) {
|
|
||||||
// Serialize to JSON, but exclude the end quote
|
|
||||||
std::string dumped = json(trim_trailing_space(node.text)).dump();
|
|
||||||
current_tool->arguments += dumped.substr(0, dumped.size() - 1);
|
|
||||||
needs_closing_quote = true;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (is_arg_close && current_tool) {
|
|
||||||
if (needs_closing_quote) {
|
|
||||||
current_tool->arguments += "\"";
|
|
||||||
needs_closing_quote = false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (is_arg_json && current_tool) {
|
|
||||||
current_tool->arguments += std::string(trim_trailing_space(node.text));
|
|
||||||
}
|
|
||||||
|
|
||||||
if (is_tool_close && current_tool) {
|
|
||||||
if (needs_closing_quote) {
|
|
||||||
current_tool->arguments += "\"";
|
|
||||||
needs_closing_quote = false;
|
|
||||||
}
|
|
||||||
current_tool->arguments += "}";
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,105 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
#include "chat.h"
|
|
||||||
#include "peg-parser.h"
|
|
||||||
|
|
||||||
class common_chat_peg_builder : public common_peg_parser_builder {
|
|
||||||
public:
|
|
||||||
static constexpr const char * REASONING_BLOCK = "reasoning-block";
|
|
||||||
static constexpr const char * REASONING = "reasoning";
|
|
||||||
static constexpr const char * CONTENT = "content";
|
|
||||||
|
|
||||||
common_peg_parser reasoning_block(const common_peg_parser & p) { return tag(REASONING_BLOCK, p); }
|
|
||||||
common_peg_parser reasoning(const common_peg_parser & p) { return tag(REASONING, p); }
|
|
||||||
common_peg_parser content(const common_peg_parser & p) { return tag(CONTENT, p); }
|
|
||||||
};
|
|
||||||
|
|
||||||
inline common_peg_arena build_chat_peg_parser(const std::function<common_peg_parser(common_chat_peg_builder & builder)> & fn) {
|
|
||||||
common_chat_peg_builder builder;
|
|
||||||
builder.set_root(fn(builder));
|
|
||||||
return builder.build();
|
|
||||||
}
|
|
||||||
|
|
||||||
class common_chat_peg_mapper {
|
|
||||||
public:
|
|
||||||
common_chat_msg & result;
|
|
||||||
|
|
||||||
common_chat_peg_mapper(common_chat_msg & msg) : result(msg) {}
|
|
||||||
|
|
||||||
virtual void from_ast(const common_peg_ast_arena & arena, const common_peg_parse_result & result);
|
|
||||||
virtual void map(const common_peg_ast_node & node);
|
|
||||||
};
|
|
||||||
|
|
||||||
class common_chat_peg_native_builder : public common_chat_peg_builder {
|
|
||||||
public:
|
|
||||||
static constexpr const char * TOOL = "tool";
|
|
||||||
static constexpr const char * TOOL_OPEN = "tool-open";
|
|
||||||
static constexpr const char * TOOL_CLOSE = "tool-close";
|
|
||||||
static constexpr const char * TOOL_ID = "tool-id";
|
|
||||||
static constexpr const char * TOOL_NAME = "tool-name";
|
|
||||||
static constexpr const char * TOOL_ARGS = "tool-args";
|
|
||||||
|
|
||||||
common_peg_parser tool(const common_peg_parser & p) { return tag(TOOL, p); }
|
|
||||||
common_peg_parser tool_open(const common_peg_parser & p) { return atomic(tag(TOOL_OPEN, p)); }
|
|
||||||
common_peg_parser tool_close(const common_peg_parser & p) { return atomic(tag(TOOL_CLOSE, p)); }
|
|
||||||
common_peg_parser tool_id(const common_peg_parser & p) { return atomic(tag(TOOL_ID, p)); }
|
|
||||||
common_peg_parser tool_name(const common_peg_parser & p) { return atomic(tag(TOOL_NAME, p)); }
|
|
||||||
common_peg_parser tool_args(const common_peg_parser & p) { return tag(TOOL_ARGS, p); }
|
|
||||||
};
|
|
||||||
|
|
||||||
class common_chat_peg_native_mapper : public common_chat_peg_mapper {
|
|
||||||
common_chat_tool_call * current_tool;
|
|
||||||
|
|
||||||
public:
|
|
||||||
common_chat_peg_native_mapper(common_chat_msg & msg) : common_chat_peg_mapper(msg) {}
|
|
||||||
|
|
||||||
void map(const common_peg_ast_node & node) override;
|
|
||||||
};
|
|
||||||
|
|
||||||
inline common_peg_arena build_chat_peg_native_parser(const std::function<common_peg_parser(common_chat_peg_native_builder & builder)> & fn) {
|
|
||||||
common_chat_peg_native_builder builder;
|
|
||||||
builder.set_root(fn(builder));
|
|
||||||
return builder.build();
|
|
||||||
}
|
|
||||||
|
|
||||||
class common_chat_peg_constructed_builder : public common_chat_peg_builder {
|
|
||||||
public:
|
|
||||||
static constexpr const char * TOOL = "tool";
|
|
||||||
static constexpr const char * TOOL_OPEN = "tool-open";
|
|
||||||
static constexpr const char * TOOL_CLOSE = "tool-close";
|
|
||||||
static constexpr const char * TOOL_NAME = "tool-name";
|
|
||||||
static constexpr const char * TOOL_ARG = "tool-arg";
|
|
||||||
static constexpr const char * TOOL_ARG_OPEN = "tool-arg-open";
|
|
||||||
static constexpr const char * TOOL_ARG_CLOSE = "tool-arg-close";
|
|
||||||
static constexpr const char * TOOL_ARG_NAME = "tool-arg-name";
|
|
||||||
static constexpr const char * TOOL_ARG_STRING_VALUE = "tool-arg-string-value";
|
|
||||||
static constexpr const char * TOOL_ARG_JSON_VALUE = "tool-arg-json-value";
|
|
||||||
|
|
||||||
common_peg_parser tool(const common_peg_parser & p) { return tag(TOOL, p); }
|
|
||||||
common_peg_parser tool_open(const common_peg_parser & p) { return atomic(tag(TOOL_OPEN, p)); }
|
|
||||||
common_peg_parser tool_close(const common_peg_parser & p) { return atomic(tag(TOOL_CLOSE, p)); }
|
|
||||||
common_peg_parser tool_name(const common_peg_parser & p) { return atomic(tag(TOOL_NAME, p)); }
|
|
||||||
common_peg_parser tool_arg(const common_peg_parser & p) { return tag(TOOL_ARG, p); }
|
|
||||||
common_peg_parser tool_arg_open(const common_peg_parser & p) { return atomic(tag(TOOL_ARG_OPEN, p)); }
|
|
||||||
common_peg_parser tool_arg_close(const common_peg_parser & p) { return atomic(tag(TOOL_ARG_CLOSE, p)); }
|
|
||||||
common_peg_parser tool_arg_name(const common_peg_parser & p) { return atomic(tag(TOOL_ARG_NAME, p)); }
|
|
||||||
common_peg_parser tool_arg_string_value(const common_peg_parser & p) { return tag(TOOL_ARG_STRING_VALUE, p); }
|
|
||||||
common_peg_parser tool_arg_json_value(const common_peg_parser & p) { return tag(TOOL_ARG_JSON_VALUE, p); }
|
|
||||||
};
|
|
||||||
|
|
||||||
class common_chat_peg_constructed_mapper : public common_chat_peg_mapper {
|
|
||||||
common_chat_tool_call * current_tool;
|
|
||||||
int arg_count = 0;
|
|
||||||
bool needs_closing_quote = false;
|
|
||||||
|
|
||||||
public:
|
|
||||||
common_chat_peg_constructed_mapper(common_chat_msg & msg) : common_chat_peg_mapper(msg) {}
|
|
||||||
|
|
||||||
void map(const common_peg_ast_node & node) override;
|
|
||||||
};
|
|
||||||
|
|
||||||
inline common_peg_arena build_chat_peg_constructed_parser(const std::function<common_peg_parser(common_chat_peg_constructed_builder & builder)> & fn) {
|
|
||||||
common_chat_peg_constructed_builder builder;
|
|
||||||
builder.set_root(fn(builder));
|
|
||||||
return builder.build();
|
|
||||||
}
|
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -1,234 +0,0 @@
|
|||||||
// Chat support (incl. tool call grammar constraining & output parsing) w/ generic & custom template handlers.
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include "common.h"
|
|
||||||
#include "peg-parser.h"
|
|
||||||
#include <functional>
|
|
||||||
#include <chrono>
|
|
||||||
#include <string>
|
|
||||||
#include <vector>
|
|
||||||
#include <map>
|
|
||||||
|
|
||||||
struct common_chat_templates;
|
|
||||||
|
|
||||||
struct common_chat_tool_call {
|
|
||||||
std::string name;
|
|
||||||
std::string arguments;
|
|
||||||
std::string id;
|
|
||||||
|
|
||||||
bool operator==(const common_chat_tool_call & other) const {
|
|
||||||
return name == other.name && arguments == other.arguments && id == other.id;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
struct common_chat_msg_content_part {
|
|
||||||
std::string type;
|
|
||||||
std::string text;
|
|
||||||
|
|
||||||
bool operator==(const common_chat_msg_content_part & other) const {
|
|
||||||
return type == other.type && text == other.text;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
struct common_chat_msg {
|
|
||||||
std::string role;
|
|
||||||
std::string content;
|
|
||||||
std::vector<common_chat_msg_content_part> content_parts;
|
|
||||||
std::vector<common_chat_tool_call> tool_calls;
|
|
||||||
std::string reasoning_content;
|
|
||||||
std::string tool_name;
|
|
||||||
std::string tool_call_id;
|
|
||||||
|
|
||||||
template <class T> T to_json_oaicompat() const;
|
|
||||||
|
|
||||||
bool empty() const {
|
|
||||||
return content.empty() && content_parts.empty() && tool_calls.empty() && reasoning_content.empty() && tool_name.empty() && tool_call_id.empty();
|
|
||||||
}
|
|
||||||
void set_tool_call_ids(std::vector<std::string> & ids_cache, const std::function<std::string()> & gen_tool_call_id) {
|
|
||||||
for (auto i = 0u; i < tool_calls.size(); i++) {
|
|
||||||
if (ids_cache.size() <= i) {
|
|
||||||
auto id = tool_calls[i].id;
|
|
||||||
if (id.empty()) {
|
|
||||||
id = gen_tool_call_id();
|
|
||||||
}
|
|
||||||
ids_cache.push_back(id);
|
|
||||||
}
|
|
||||||
tool_calls[i].id = ids_cache[i];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
bool operator==(const common_chat_msg & other) const {
|
|
||||||
return role == other.role
|
|
||||||
&& content == other.content
|
|
||||||
&& content_parts == other.content_parts
|
|
||||||
&& tool_calls == other.tool_calls
|
|
||||||
&& reasoning_content == other.reasoning_content
|
|
||||||
&& tool_name == other.tool_name
|
|
||||||
&& tool_call_id == other.tool_call_id;
|
|
||||||
}
|
|
||||||
bool operator!=(const common_chat_msg & other) const {
|
|
||||||
return !(*this == other);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
struct common_chat_msg_diff {
|
|
||||||
std::string reasoning_content_delta;
|
|
||||||
std::string content_delta;
|
|
||||||
size_t tool_call_index = std::string::npos;
|
|
||||||
common_chat_tool_call tool_call_delta;
|
|
||||||
|
|
||||||
static std::vector<common_chat_msg_diff> compute_diffs(const common_chat_msg & msg_prv, const common_chat_msg & msg_new);
|
|
||||||
|
|
||||||
bool operator==(const common_chat_msg_diff & other) const {
|
|
||||||
return content_delta == other.content_delta
|
|
||||||
&& tool_call_index == other.tool_call_index
|
|
||||||
&& tool_call_delta == other.tool_call_delta;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
struct common_chat_tool {
|
|
||||||
std::string name;
|
|
||||||
std::string description;
|
|
||||||
std::string parameters;
|
|
||||||
};
|
|
||||||
|
|
||||||
enum common_chat_tool_choice {
|
|
||||||
COMMON_CHAT_TOOL_CHOICE_AUTO,
|
|
||||||
COMMON_CHAT_TOOL_CHOICE_REQUIRED,
|
|
||||||
COMMON_CHAT_TOOL_CHOICE_NONE,
|
|
||||||
};
|
|
||||||
|
|
||||||
enum common_chat_format {
|
|
||||||
COMMON_CHAT_FORMAT_CONTENT_ONLY,
|
|
||||||
COMMON_CHAT_FORMAT_GENERIC,
|
|
||||||
COMMON_CHAT_FORMAT_MISTRAL_NEMO,
|
|
||||||
COMMON_CHAT_FORMAT_MAGISTRAL,
|
|
||||||
COMMON_CHAT_FORMAT_LLAMA_3_X,
|
|
||||||
COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS,
|
|
||||||
COMMON_CHAT_FORMAT_DEEPSEEK_R1,
|
|
||||||
COMMON_CHAT_FORMAT_FIREFUNCTION_V2,
|
|
||||||
COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2,
|
|
||||||
COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
|
|
||||||
COMMON_CHAT_FORMAT_DEEPSEEK_V3_1,
|
|
||||||
COMMON_CHAT_FORMAT_HERMES_2_PRO,
|
|
||||||
COMMON_CHAT_FORMAT_COMMAND_R7B,
|
|
||||||
COMMON_CHAT_FORMAT_GRANITE,
|
|
||||||
COMMON_CHAT_FORMAT_GPT_OSS,
|
|
||||||
COMMON_CHAT_FORMAT_SEED_OSS,
|
|
||||||
COMMON_CHAT_FORMAT_NEMOTRON_V2,
|
|
||||||
COMMON_CHAT_FORMAT_APERTUS,
|
|
||||||
COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS,
|
|
||||||
COMMON_CHAT_FORMAT_GLM_4_5,
|
|
||||||
COMMON_CHAT_FORMAT_MINIMAX_M2,
|
|
||||||
COMMON_CHAT_FORMAT_KIMI_K2,
|
|
||||||
COMMON_CHAT_FORMAT_QWEN3_CODER_XML,
|
|
||||||
COMMON_CHAT_FORMAT_APRIEL_1_5,
|
|
||||||
COMMON_CHAT_FORMAT_XIAOMI_MIMO,
|
|
||||||
COMMON_CHAT_FORMAT_SOLAR_OPEN,
|
|
||||||
|
|
||||||
// These are intended to be parsed by the PEG parser
|
|
||||||
COMMON_CHAT_FORMAT_PEG_SIMPLE,
|
|
||||||
COMMON_CHAT_FORMAT_PEG_NATIVE,
|
|
||||||
COMMON_CHAT_FORMAT_PEG_CONSTRUCTED,
|
|
||||||
|
|
||||||
COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
|
|
||||||
};
|
|
||||||
|
|
||||||
struct common_chat_templates_inputs {
|
|
||||||
std::vector<common_chat_msg> messages;
|
|
||||||
std::string grammar;
|
|
||||||
std::string json_schema;
|
|
||||||
bool add_generation_prompt = true;
|
|
||||||
bool use_jinja = true;
|
|
||||||
// Parameters below only supported when use_jinja is true
|
|
||||||
std::vector<common_chat_tool> tools;
|
|
||||||
common_chat_tool_choice tool_choice = COMMON_CHAT_TOOL_CHOICE_AUTO;
|
|
||||||
bool parallel_tool_calls = false;
|
|
||||||
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE;
|
|
||||||
bool enable_thinking = true;
|
|
||||||
std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
|
|
||||||
std::map<std::string, std::string> chat_template_kwargs;
|
|
||||||
bool add_bos = false;
|
|
||||||
bool add_eos = false;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct common_chat_params {
|
|
||||||
common_chat_format format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
|
|
||||||
std::string prompt;
|
|
||||||
std::string grammar;
|
|
||||||
bool grammar_lazy = false;
|
|
||||||
bool thinking_forced_open = false;
|
|
||||||
std::vector<common_grammar_trigger> grammar_triggers;
|
|
||||||
std::vector<std::string> preserved_tokens;
|
|
||||||
std::vector<std::string> additional_stops;
|
|
||||||
std::string parser;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct common_chat_syntax {
|
|
||||||
common_chat_format format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
|
|
||||||
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE;
|
|
||||||
// Whether reasoning_content should be inlined in the content (e.g. for reasoning_format=deepseek in stream mode)
|
|
||||||
bool reasoning_in_content = false;
|
|
||||||
bool thinking_forced_open = false;
|
|
||||||
bool parse_tool_calls = true;
|
|
||||||
common_peg_arena parser = {};
|
|
||||||
};
|
|
||||||
|
|
||||||
// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
|
|
||||||
bool common_chat_verify_template(const std::string & tmpl, bool use_jinja);
|
|
||||||
|
|
||||||
void common_chat_templates_free(struct common_chat_templates * tmpls);
|
|
||||||
|
|
||||||
struct common_chat_templates_deleter { void operator()(common_chat_templates * tmpls) { common_chat_templates_free(tmpls); } };
|
|
||||||
|
|
||||||
typedef std::unique_ptr<struct common_chat_templates, common_chat_templates_deleter> common_chat_templates_ptr;
|
|
||||||
|
|
||||||
common_chat_templates_ptr common_chat_templates_init(
|
|
||||||
const struct llama_model * model,
|
|
||||||
const std::string & chat_template_override,
|
|
||||||
const std::string & bos_token_override = "",
|
|
||||||
const std::string & eos_token_override = "");
|
|
||||||
|
|
||||||
bool common_chat_templates_was_explicit(const struct common_chat_templates * tmpls);
|
|
||||||
const char * common_chat_templates_source(const struct common_chat_templates * tmpls, const char * variant = nullptr);
|
|
||||||
|
|
||||||
|
|
||||||
struct common_chat_params common_chat_templates_apply(
|
|
||||||
const struct common_chat_templates * tmpls,
|
|
||||||
const struct common_chat_templates_inputs & inputs);
|
|
||||||
|
|
||||||
// Format single message, while taking into account the position of that message in chat history
|
|
||||||
std::string common_chat_format_single(
|
|
||||||
const struct common_chat_templates * tmpls,
|
|
||||||
const std::vector<common_chat_msg> & past_msg,
|
|
||||||
const common_chat_msg & new_msg,
|
|
||||||
bool add_ass,
|
|
||||||
bool use_jinja);
|
|
||||||
|
|
||||||
// Returns an example of formatted chat
|
|
||||||
std::string common_chat_format_example(
|
|
||||||
const struct common_chat_templates * tmpls,
|
|
||||||
bool use_jinja,
|
|
||||||
const std::map<std::string, std::string> & chat_template_kwargs);
|
|
||||||
|
|
||||||
const char* common_chat_format_name(common_chat_format format);
|
|
||||||
const char* common_reasoning_format_name(common_reasoning_format format);
|
|
||||||
common_reasoning_format common_reasoning_format_from_name(const std::string & format);
|
|
||||||
common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_syntax & syntax);
|
|
||||||
common_chat_msg common_chat_peg_parse(const common_peg_arena & parser, const std::string & input, bool is_partial, const common_chat_syntax & syntax);
|
|
||||||
|
|
||||||
common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice);
|
|
||||||
|
|
||||||
bool common_chat_templates_support_enable_thinking(const common_chat_templates * chat_templates);
|
|
||||||
|
|
||||||
// Parses a JSON array of messages in OpenAI's chat completion API format.
|
|
||||||
// T can be std::string containing JSON or nlohmann::ordered_json
|
|
||||||
template <class T> std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const T & messages);
|
|
||||||
template <class T> T common_chat_msgs_to_json_oaicompat(const std::vector<common_chat_msg> & msgs, bool concat_typed_text = false);
|
|
||||||
|
|
||||||
// Parses a JSON array of tools in OpenAI's chat completion tool call API format.
|
|
||||||
// T can be std::string containing JSON or nlohmann::ordered_json
|
|
||||||
template <class T> std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const T & tools);
|
|
||||||
template <class T> T common_chat_tools_to_json_oaicompat(const std::vector<common_chat_tool> & tools);
|
|
||||||
|
|
||||||
template <class T> T common_chat_msg_diff_to_json_oaicompat(const common_chat_msg_diff & diff);
|
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -1,858 +0,0 @@
|
|||||||
// Various helper functions and utilities
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include "ggml-opt.h"
|
|
||||||
#include "llama-cpp.h"
|
|
||||||
|
|
||||||
#include <set>
|
|
||||||
#include <sstream>
|
|
||||||
#include <string>
|
|
||||||
#include <string_view>
|
|
||||||
#include <vector>
|
|
||||||
#include <map>
|
|
||||||
|
|
||||||
#if defined(_WIN32) && !defined(_WIN32_WINNT)
|
|
||||||
#define _WIN32_WINNT 0x0A00
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#ifdef _WIN32
|
|
||||||
#define DIRECTORY_SEPARATOR '\\'
|
|
||||||
#else
|
|
||||||
#define DIRECTORY_SEPARATOR '/'
|
|
||||||
#endif // _WIN32
|
|
||||||
|
|
||||||
#define die(msg) do { fputs("error: " msg "\n", stderr); exit(1); } while (0)
|
|
||||||
#define die_fmt(fmt, ...) do { fprintf(stderr, "error: " fmt "\n", __VA_ARGS__); exit(1); } while (0)
|
|
||||||
|
|
||||||
#define print_build_info() do { \
|
|
||||||
fprintf(stderr, "%s: build = %d (%s)\n", __func__, LLAMA_BUILD_NUMBER, LLAMA_COMMIT); \
|
|
||||||
fprintf(stderr, "%s: built with %s for %s\n", __func__, LLAMA_COMPILER, LLAMA_BUILD_TARGET); \
|
|
||||||
} while(0)
|
|
||||||
|
|
||||||
struct common_time_meas {
|
|
||||||
common_time_meas(int64_t & t_acc, bool disable = false);
|
|
||||||
~common_time_meas();
|
|
||||||
|
|
||||||
const int64_t t_start_us;
|
|
||||||
|
|
||||||
int64_t & t_acc;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct common_adapter_lora_info {
|
|
||||||
std::string path;
|
|
||||||
float scale;
|
|
||||||
|
|
||||||
std::string task_name;
|
|
||||||
std::string prompt_prefix;
|
|
||||||
|
|
||||||
struct llama_adapter_lora * ptr;
|
|
||||||
};
|
|
||||||
|
|
||||||
using llama_tokens = std::vector<llama_token>;
|
|
||||||
|
|
||||||
// build info
|
|
||||||
extern int LLAMA_BUILD_NUMBER;
|
|
||||||
extern const char * LLAMA_COMMIT;
|
|
||||||
extern const char * LLAMA_COMPILER;
|
|
||||||
extern const char * LLAMA_BUILD_TARGET;
|
|
||||||
|
|
||||||
struct common_control_vector_load_info;
|
|
||||||
|
|
||||||
//
|
|
||||||
// CPU utils
|
|
||||||
//
|
|
||||||
|
|
||||||
struct cpu_params {
|
|
||||||
int n_threads = -1;
|
|
||||||
bool cpumask[GGML_MAX_N_THREADS] = {false}; // CPU affinity mask.
|
|
||||||
bool mask_valid = false; // Default: any CPU
|
|
||||||
enum ggml_sched_priority priority = GGML_SCHED_PRIO_NORMAL; // Scheduling prio : (0 - normal, 1 - medium, 2 - high, 3 - realtime)
|
|
||||||
bool strict_cpu = false; // Use strict CPU placement
|
|
||||||
uint32_t poll = 50; // Polling (busywait) level (0 - no polling, 100 - mostly polling)
|
|
||||||
};
|
|
||||||
|
|
||||||
int32_t cpu_get_num_physical_cores();
|
|
||||||
int32_t cpu_get_num_math();
|
|
||||||
|
|
||||||
//
|
|
||||||
// Common params
|
|
||||||
//
|
|
||||||
|
|
||||||
enum llama_example {
|
|
||||||
LLAMA_EXAMPLE_DEBUG,
|
|
||||||
LLAMA_EXAMPLE_COMMON,
|
|
||||||
LLAMA_EXAMPLE_SPECULATIVE,
|
|
||||||
LLAMA_EXAMPLE_COMPLETION,
|
|
||||||
LLAMA_EXAMPLE_CLI,
|
|
||||||
LLAMA_EXAMPLE_EMBEDDING,
|
|
||||||
LLAMA_EXAMPLE_PERPLEXITY,
|
|
||||||
LLAMA_EXAMPLE_RETRIEVAL,
|
|
||||||
LLAMA_EXAMPLE_PASSKEY,
|
|
||||||
LLAMA_EXAMPLE_IMATRIX,
|
|
||||||
LLAMA_EXAMPLE_BENCH,
|
|
||||||
LLAMA_EXAMPLE_SERVER,
|
|
||||||
LLAMA_EXAMPLE_CVECTOR_GENERATOR,
|
|
||||||
LLAMA_EXAMPLE_EXPORT_LORA,
|
|
||||||
LLAMA_EXAMPLE_MTMD,
|
|
||||||
LLAMA_EXAMPLE_LOOKUP,
|
|
||||||
LLAMA_EXAMPLE_PARALLEL,
|
|
||||||
LLAMA_EXAMPLE_TTS,
|
|
||||||
LLAMA_EXAMPLE_DIFFUSION,
|
|
||||||
LLAMA_EXAMPLE_FINETUNE,
|
|
||||||
LLAMA_EXAMPLE_FIT_PARAMS,
|
|
||||||
|
|
||||||
LLAMA_EXAMPLE_COUNT,
|
|
||||||
};
|
|
||||||
|
|
||||||
enum common_sampler_type {
|
|
||||||
COMMON_SAMPLER_TYPE_NONE = 0,
|
|
||||||
COMMON_SAMPLER_TYPE_DRY = 1,
|
|
||||||
COMMON_SAMPLER_TYPE_TOP_K = 2,
|
|
||||||
COMMON_SAMPLER_TYPE_TOP_P = 3,
|
|
||||||
COMMON_SAMPLER_TYPE_MIN_P = 4,
|
|
||||||
//COMMON_SAMPLER_TYPE_TFS_Z = 5,
|
|
||||||
COMMON_SAMPLER_TYPE_TYPICAL_P = 6,
|
|
||||||
COMMON_SAMPLER_TYPE_TEMPERATURE = 7,
|
|
||||||
COMMON_SAMPLER_TYPE_XTC = 8,
|
|
||||||
COMMON_SAMPLER_TYPE_INFILL = 9,
|
|
||||||
COMMON_SAMPLER_TYPE_PENALTIES = 10,
|
|
||||||
COMMON_SAMPLER_TYPE_TOP_N_SIGMA = 11,
|
|
||||||
};
|
|
||||||
|
|
||||||
// dimensionality reduction methods, used by cvector-generator
|
|
||||||
enum dimre_method {
|
|
||||||
DIMRE_METHOD_PCA,
|
|
||||||
DIMRE_METHOD_MEAN,
|
|
||||||
};
|
|
||||||
|
|
||||||
enum common_conversation_mode {
|
|
||||||
COMMON_CONVERSATION_MODE_DISABLED = 0,
|
|
||||||
COMMON_CONVERSATION_MODE_ENABLED = 1,
|
|
||||||
COMMON_CONVERSATION_MODE_AUTO = 2,
|
|
||||||
};
|
|
||||||
|
|
||||||
enum common_grammar_trigger_type {
|
|
||||||
COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN,
|
|
||||||
COMMON_GRAMMAR_TRIGGER_TYPE_WORD,
|
|
||||||
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN,
|
|
||||||
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL,
|
|
||||||
};
|
|
||||||
|
|
||||||
struct common_grammar_trigger {
|
|
||||||
common_grammar_trigger_type type;
|
|
||||||
std::string value;
|
|
||||||
llama_token token = LLAMA_TOKEN_NULL;
|
|
||||||
};
|
|
||||||
|
|
||||||
enum common_params_sampling_config : uint64_t {
|
|
||||||
COMMON_PARAMS_SAMPLING_CONFIG_SAMPLERS = 1 << 0,
|
|
||||||
COMMON_PARAMS_SAMPLING_CONFIG_TOP_K = 1 << 1,
|
|
||||||
COMMON_PARAMS_SAMPLING_CONFIG_TOP_P = 1 << 2,
|
|
||||||
COMMON_PARAMS_SAMPLING_CONFIG_MIN_P = 1 << 3,
|
|
||||||
COMMON_PARAMS_SAMPLING_CONFIG_XTC_PROBABILITY = 1 << 4,
|
|
||||||
COMMON_PARAMS_SAMPLING_CONFIG_XTC_THRESHOLD = 1 << 5,
|
|
||||||
COMMON_PARAMS_SAMPLING_CONFIG_TEMP = 1 << 6,
|
|
||||||
COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_LAST_N = 1 << 7,
|
|
||||||
COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_REPEAT = 1 << 8,
|
|
||||||
COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT = 1 << 9,
|
|
||||||
COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_TAU = 1 << 10,
|
|
||||||
COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_ETA = 1 << 11,
|
|
||||||
};
|
|
||||||
|
|
||||||
|
|
||||||
// sampling parameters
|
|
||||||
struct common_params_sampling {
|
|
||||||
uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler
|
|
||||||
|
|
||||||
int32_t n_prev = 64; // number of previous tokens to remember
|
|
||||||
int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
|
|
||||||
int32_t min_keep = 0; // 0 = disabled, otherwise samplers should return at least min_keep tokens
|
|
||||||
int32_t top_k = 40; // <= 0 to use vocab size
|
|
||||||
float top_p = 0.95f; // 1.0 = disabled
|
|
||||||
float min_p = 0.05f; // 0.0 = disabled
|
|
||||||
float xtc_probability = 0.00f; // 0.0 = disabled
|
|
||||||
float xtc_threshold = 0.10f; // > 0.5 disables XTC
|
|
||||||
float typ_p = 1.00f; // typical_p, 1.0 = disabled
|
|
||||||
float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities
|
|
||||||
float dynatemp_range = 0.00f; // 0.0 = disabled
|
|
||||||
float dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler
|
|
||||||
int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
|
|
||||||
float penalty_repeat = 1.00f; // 1.0 = disabled
|
|
||||||
float penalty_freq = 0.00f; // 0.0 = disabled
|
|
||||||
float penalty_present = 0.00f; // 0.0 = disabled
|
|
||||||
float dry_multiplier = 0.0f; // 0.0 = disabled; DRY repetition penalty for tokens extending repetition:
|
|
||||||
float dry_base = 1.75f; // 0.0 = disabled; multiplier * base ^ (length of sequence before token - allowed length)
|
|
||||||
int32_t dry_allowed_length = 2; // tokens extending repetitions beyond this receive penalty
|
|
||||||
int32_t dry_penalty_last_n = -1; // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size)
|
|
||||||
int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
|
|
||||||
float top_n_sigma = -1.00f;// -1.0 = disabled
|
|
||||||
float mirostat_tau = 5.00f; // target entropy
|
|
||||||
float mirostat_eta = 0.10f; // learning rate
|
|
||||||
bool ignore_eos = false;
|
|
||||||
bool no_perf = false; // disable performance metrics
|
|
||||||
bool timing_per_token = false;
|
|
||||||
|
|
||||||
uint64_t user_sampling_config = 0; // bitfield to track user-specified samplers
|
|
||||||
|
|
||||||
std::vector<std::string> dry_sequence_breakers = {"\n", ":", "\"", "*"}; // default sequence breakers for DRY
|
|
||||||
|
|
||||||
std::vector<enum common_sampler_type> samplers = {
|
|
||||||
COMMON_SAMPLER_TYPE_PENALTIES,
|
|
||||||
COMMON_SAMPLER_TYPE_DRY,
|
|
||||||
COMMON_SAMPLER_TYPE_TOP_N_SIGMA,
|
|
||||||
COMMON_SAMPLER_TYPE_TOP_K,
|
|
||||||
COMMON_SAMPLER_TYPE_TYPICAL_P,
|
|
||||||
COMMON_SAMPLER_TYPE_TOP_P,
|
|
||||||
COMMON_SAMPLER_TYPE_MIN_P,
|
|
||||||
COMMON_SAMPLER_TYPE_XTC,
|
|
||||||
COMMON_SAMPLER_TYPE_TEMPERATURE,
|
|
||||||
};
|
|
||||||
|
|
||||||
std::string grammar; // optional BNF-like grammar to constrain sampling
|
|
||||||
bool grammar_lazy = false;
|
|
||||||
std::vector<common_grammar_trigger> grammar_triggers; // optional triggers (for lazy grammars)
|
|
||||||
std::set<llama_token> preserved_tokens;
|
|
||||||
|
|
||||||
std::vector<llama_logit_bias> logit_bias; // logit biases to apply
|
|
||||||
std::vector<llama_logit_bias> logit_bias_eog; // pre-calculated logit biases for EOG tokens
|
|
||||||
|
|
||||||
bool backend_sampling = false;
|
|
||||||
|
|
||||||
bool has_logit_bias() const {
|
|
||||||
return !logit_bias.empty();
|
|
||||||
}
|
|
||||||
|
|
||||||
// print the parameters into a string
|
|
||||||
std::string print() const;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct common_params_model {
|
|
||||||
std::string path = ""; // model local path // NOLINT
|
|
||||||
std::string url = ""; // model url to download // NOLINT
|
|
||||||
std::string hf_repo = ""; // HF repo // NOLINT
|
|
||||||
std::string hf_file = ""; // HF file // NOLINT
|
|
||||||
std::string docker_repo = ""; // Docker repo // NOLINT
|
|
||||||
std::string name = ""; // in format <user>/<model>[:<tag>] (tag is optional) // NOLINT
|
|
||||||
};
|
|
||||||
|
|
||||||
struct common_params_speculative {
|
|
||||||
std::vector<ggml_backend_dev_t> devices; // devices to use for offloading
|
|
||||||
|
|
||||||
int32_t n_ctx = 0; // draft context size
|
|
||||||
int32_t n_max = 16; // maximum number of tokens to draft during speculative decoding
|
|
||||||
int32_t n_min = 0; // minimum number of draft tokens to use for speculative decoding
|
|
||||||
int32_t n_gpu_layers = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
|
|
||||||
float p_split = 0.1f; // speculative decoding split probability
|
|
||||||
float p_min = 0.75f; // minimum speculative decoding probability (greedy)
|
|
||||||
std::vector<std::pair<std::string, std::string>> replacements; // main to speculative model replacements
|
|
||||||
std::vector<llama_model_tensor_buft_override> tensor_buft_overrides;
|
|
||||||
|
|
||||||
ggml_type cache_type_k = GGML_TYPE_F16; // KV cache data type for the K
|
|
||||||
ggml_type cache_type_v = GGML_TYPE_F16; // KV cache data type for the V
|
|
||||||
|
|
||||||
struct cpu_params cpuparams;
|
|
||||||
struct cpu_params cpuparams_batch;
|
|
||||||
|
|
||||||
struct common_params_model model;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct common_params_vocoder {
|
|
||||||
struct common_params_model model;
|
|
||||||
|
|
||||||
std::string speaker_file = ""; // speaker file path // NOLINT
|
|
||||||
|
|
||||||
bool use_guide_tokens = false; // enable guide tokens to improve TTS accuracy // NOLINT
|
|
||||||
};
|
|
||||||
|
|
||||||
struct common_params_diffusion {
|
|
||||||
int32_t steps = 128;
|
|
||||||
bool visual_mode = false;
|
|
||||||
|
|
||||||
float eps = 0; // epsilon for timesteps
|
|
||||||
int32_t block_length = 0; // block length for generation
|
|
||||||
|
|
||||||
int32_t algorithm = 4; // default algorithm: low-confidence
|
|
||||||
float alg_temp = 0.0f; // algorithm temperature
|
|
||||||
|
|
||||||
float cfg_scale = 0; // classifier-free guidance scale
|
|
||||||
bool add_gumbel_noise = false; // add gumbel noise to the logits if temp > 0.0
|
|
||||||
};
|
|
||||||
|
|
||||||
// reasoning API response format (not to be confused as chat template's reasoning format)
|
|
||||||
enum common_reasoning_format {
|
|
||||||
COMMON_REASONING_FORMAT_NONE,
|
|
||||||
COMMON_REASONING_FORMAT_AUTO, // Same as deepseek, using `message.reasoning_content`
|
|
||||||
COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY, // Extract thinking tag contents and return as `message.reasoning_content`, or leave inline in <think> tags in stream mode
|
|
||||||
COMMON_REASONING_FORMAT_DEEPSEEK, // Extract thinking tag contents and return as `message.reasoning_content`, including in streaming deltas.
|
|
||||||
// do not extend this enum unless you absolutely have to
|
|
||||||
// in most cases, use COMMON_REASONING_FORMAT_AUTO
|
|
||||||
// see: https://github.com/ggml-org/llama.cpp/pull/15408
|
|
||||||
};
|
|
||||||
|
|
||||||
|
|
||||||
struct lr_opt {
|
|
||||||
float lr0 = 1e-5; // learning rate at first epoch
|
|
||||||
float lr_min = -1;
|
|
||||||
float decay_epochs = -1; // if >0, the learning rate starts at lr0 and decays to lr_min after this many epochs
|
|
||||||
float scale_epoch = 0;
|
|
||||||
float wd = 0;
|
|
||||||
unsigned epochs = 2;
|
|
||||||
|
|
||||||
unsigned epoch; // set by optimizer outer (epochs) loop
|
|
||||||
// learning rate decay - constant LR per epoch only for now
|
|
||||||
float get_lr(float e) const;
|
|
||||||
float get_lr() const { return get_lr(epoch); }
|
|
||||||
// must call after arg parse, before get_lr
|
|
||||||
void init();
|
|
||||||
};
|
|
||||||
|
|
||||||
struct ggml_opt_optimizer_params common_opt_lr_pars(void * userdata);
|
|
||||||
|
|
||||||
struct common_params {
|
|
||||||
int32_t n_predict = -1; // max. number of new tokens to predict, -1 == no limit
|
|
||||||
int32_t n_ctx = 0; // context size, 0 == context the model was trained with
|
|
||||||
int32_t n_batch = 2048; // logical batch size for prompt processing (must be >=32 to use BLAS)
|
|
||||||
int32_t n_ubatch = 512; // physical batch size for prompt processing (must be >=32 to use BLAS)
|
|
||||||
int32_t n_keep = 0; // number of tokens to keep from initial prompt
|
|
||||||
int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited)
|
|
||||||
int32_t n_parallel = 1; // number of parallel sequences to decode
|
|
||||||
int32_t n_sequences = 1; // number of sequences to decode
|
|
||||||
int32_t grp_attn_n = 1; // group-attention factor
|
|
||||||
int32_t grp_attn_w = 512; // group-attention width
|
|
||||||
int32_t n_print = -1; // print token count every n tokens (-1 = disabled)
|
|
||||||
float rope_freq_base = 0.0f; // RoPE base frequency
|
|
||||||
float rope_freq_scale = 0.0f; // RoPE frequency scaling factor
|
|
||||||
float yarn_ext_factor = -1.0f; // YaRN extrapolation mix factor
|
|
||||||
float yarn_attn_factor = -1.0f; // YaRN magnitude scaling factor
|
|
||||||
float yarn_beta_fast = -1.0f; // YaRN low correction dim
|
|
||||||
float yarn_beta_slow = -1.0f; // YaRN high correction dim
|
|
||||||
int32_t yarn_orig_ctx = 0; // YaRN original context length
|
|
||||||
|
|
||||||
// offload params
|
|
||||||
std::vector<ggml_backend_dev_t> devices; // devices to use for offloading
|
|
||||||
|
|
||||||
int32_t n_gpu_layers = -1; // number of layers to store in VRAM, -1 is auto, <= -2 is all
|
|
||||||
int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
|
|
||||||
float tensor_split[128] = {0}; // how split tensors should be distributed across GPUs
|
|
||||||
bool fit_params = true; // whether to fit unset model/context parameters to free device memory
|
|
||||||
int32_t fit_params_min_ctx = 4096; // minimum context size to set when trying to reduce memory use
|
|
||||||
|
|
||||||
// margin per device in bytes for fitting parameters to free memory:
|
|
||||||
std::vector<size_t> fit_params_target = std::vector<size_t>(llama_max_devices(), 1024 * 1024*1024);
|
|
||||||
|
|
||||||
enum llama_split_mode split_mode = LLAMA_SPLIT_MODE_LAYER; // how to split the model across GPUs
|
|
||||||
|
|
||||||
struct cpu_params cpuparams;
|
|
||||||
struct cpu_params cpuparams_batch;
|
|
||||||
|
|
||||||
ggml_backend_sched_eval_callback cb_eval = nullptr;
|
|
||||||
void * cb_eval_user_data = nullptr;
|
|
||||||
|
|
||||||
ggml_numa_strategy numa = GGML_NUMA_STRATEGY_DISABLED;
|
|
||||||
|
|
||||||
enum llama_rope_scaling_type rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED;
|
|
||||||
enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings
|
|
||||||
enum llama_attention_type attention_type = LLAMA_ATTENTION_TYPE_UNSPECIFIED; // attention type for embeddings
|
|
||||||
enum llama_flash_attn_type flash_attn_type = LLAMA_FLASH_ATTN_TYPE_AUTO; // whether to use Flash Attention
|
|
||||||
|
|
||||||
struct common_params_sampling sampling;
|
|
||||||
struct common_params_speculative speculative;
|
|
||||||
struct common_params_vocoder vocoder;
|
|
||||||
struct common_params_diffusion diffusion;
|
|
||||||
|
|
||||||
struct common_params_model model;
|
|
||||||
|
|
||||||
std::string model_alias = ""; // model alias // NOLINT
|
|
||||||
std::string hf_token = ""; // HF token // NOLINT
|
|
||||||
std::string prompt = ""; // NOLINT
|
|
||||||
std::string system_prompt = ""; // NOLINT
|
|
||||||
std::string prompt_file = ""; // store the external prompt file name // NOLINT
|
|
||||||
std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state // NOLINT
|
|
||||||
std::string input_prefix = ""; // string to prefix user inputs with // NOLINT
|
|
||||||
std::string input_suffix = ""; // string to suffix user inputs with // NOLINT
|
|
||||||
std::string lookup_cache_static = ""; // path of static ngram cache file for lookup decoding // NOLINT
|
|
||||||
std::string lookup_cache_dynamic = ""; // path of dynamic ngram cache file for lookup decoding // NOLINT
|
|
||||||
std::string logits_file = ""; // file for saving *all* logits // NOLINT
|
|
||||||
|
|
||||||
// llama-debug specific options
|
|
||||||
std::string logits_output_dir = "data"; // directory for saving logits output files // NOLINT
|
|
||||||
bool save_logits = false; // whether to save logits to files // NOLINT
|
|
||||||
std::vector<std::string> tensor_filter; // filter tensor names for debug output (regex) // NOLINT
|
|
||||||
|
|
||||||
std::vector<std::string> in_files; // all input files
|
|
||||||
std::vector<std::string> antiprompt; // strings upon which more user input is prompted (a.k.a. reverse prompts)
|
|
||||||
std::vector<llama_model_kv_override> kv_overrides;
|
|
||||||
std::vector<llama_model_tensor_buft_override> tensor_buft_overrides;
|
|
||||||
|
|
||||||
bool lora_init_without_apply = false; // only load lora to memory, but do not apply it to ctx (user can manually apply lora later using llama_adapter_lora_apply)
|
|
||||||
std::vector<common_adapter_lora_info> lora_adapters; // lora adapter path with user defined scale
|
|
||||||
|
|
||||||
std::vector<common_control_vector_load_info> control_vectors; // control vector with user defined scale
|
|
||||||
|
|
||||||
int32_t verbosity = 3; // LOG_LEVEL_INFO
|
|
||||||
int32_t control_vector_layer_start = -1; // layer range for control vector
|
|
||||||
int32_t control_vector_layer_end = -1; // layer range for control vector
|
|
||||||
bool offline = false;
|
|
||||||
|
|
||||||
int32_t ppl_stride = 0; // stride for perplexity calculations. If left at 0, the pre-existing approach will be used.
|
|
||||||
int32_t ppl_output_type = 0; // = 0 -> ppl output is as usual, = 1 -> ppl output is num_tokens, ppl, one per line
|
|
||||||
// (which is more convenient to use for plotting)
|
|
||||||
//
|
|
||||||
bool hellaswag = false; // compute HellaSwag score over random tasks from datafile supplied in prompt
|
|
||||||
size_t hellaswag_tasks = 400; // number of tasks to use when computing the HellaSwag score
|
|
||||||
|
|
||||||
bool winogrande = false; // compute Winogrande score over random tasks from datafile supplied in prompt
|
|
||||||
size_t winogrande_tasks = 0; // number of tasks to use when computing the Winogrande score. If 0, all tasks will be computed
|
|
||||||
|
|
||||||
bool multiple_choice = false; // compute TruthfulQA score over random tasks from datafile supplied in prompt
|
|
||||||
size_t multiple_choice_tasks = 0; // number of tasks to use when computing the TruthfulQA score. If 0, all tasks will be computed
|
|
||||||
|
|
||||||
bool kl_divergence = false; // compute KL divergence
|
|
||||||
|
|
||||||
bool usage = false; // print usage
|
|
||||||
bool completion = false; // print source-able completion script
|
|
||||||
bool use_color = false; // use color to distinguish generations and inputs
|
|
||||||
bool special = false; // enable special token output
|
|
||||||
bool interactive = false; // interactive mode
|
|
||||||
bool interactive_first = false; // wait for user input immediately
|
|
||||||
bool prompt_cache_all = false; // save user input and generations to prompt cache
|
|
||||||
bool prompt_cache_ro = false; // open the prompt cache read-only and do not update it
|
|
||||||
|
|
||||||
bool escape = true; // escape "\n", "\r", "\t", "\'", "\"", and "\\"
|
|
||||||
bool multiline_input = false; // reverse the usage of `\`
|
|
||||||
bool simple_io = false; // improves compatibility with subprocesses and limited consoles
|
|
||||||
bool cont_batching = true; // insert new sequences for decoding on-the-fly
|
|
||||||
bool no_perf = false; // disable performance metrics
|
|
||||||
bool show_timings = true; // show timing information on CLI
|
|
||||||
bool ctx_shift = false; // context shift on infinite text generation
|
|
||||||
bool swa_full = false; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
|
|
||||||
bool kv_unified = false; // enable unified KV cache
|
|
||||||
|
|
||||||
bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix
|
|
||||||
bool use_mmap = true; // enable mmap to use filesystem cache
|
|
||||||
bool use_direct_io = true; // read from disk without buffering for faster model loading
|
|
||||||
bool use_mlock = false; // use mlock to keep model in memory
|
|
||||||
bool verbose_prompt = false; // print prompt tokens before generation
|
|
||||||
bool display_prompt = true; // print prompt before generation
|
|
||||||
bool no_kv_offload = false; // disable KV offloading
|
|
||||||
bool warmup = true; // warmup run
|
|
||||||
bool check_tensors = false; // validate tensor data
|
|
||||||
bool no_op_offload = false; // globally disable offload host tensor operations to device
|
|
||||||
bool no_extra_bufts = false; // disable extra buffer types (used for weight repacking)
|
|
||||||
bool no_host = false; // bypass host buffer allowing extra buffers to be used
|
|
||||||
|
|
||||||
bool single_turn = false; // single turn chat conversation
|
|
||||||
|
|
||||||
ggml_type cache_type_k = GGML_TYPE_F16; // KV cache data type for the K
|
|
||||||
ggml_type cache_type_v = GGML_TYPE_F16; // KV cache data type for the V
|
|
||||||
|
|
||||||
common_conversation_mode conversation_mode = COMMON_CONVERSATION_MODE_AUTO;
|
|
||||||
|
|
||||||
// multimodal models (see tools/mtmd)
|
|
||||||
struct common_params_model mmproj;
|
|
||||||
bool mmproj_use_gpu = true; // use GPU for multimodal model
|
|
||||||
bool no_mmproj = false; // explicitly disable multimodal model
|
|
||||||
std::vector<std::string> image; // path to image file(s)
|
|
||||||
int image_min_tokens = -1;
|
|
||||||
int image_max_tokens = -1;
|
|
||||||
|
|
||||||
// finetune
|
|
||||||
struct lr_opt lr;
|
|
||||||
enum ggml_opt_optimizer_type optimizer = GGML_OPT_OPTIMIZER_TYPE_ADAMW;
|
|
||||||
float val_split = 0.05f; // fraction of the data used for the validation set
|
|
||||||
|
|
||||||
// embedding
|
|
||||||
bool embedding = false; // get only sentence embedding
|
|
||||||
int32_t embd_normalize = 2; // normalisation for embeddings (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm)
|
|
||||||
std::string embd_out = ""; // empty = default, "array" = [[],[]...], "json" = openai style, "json+" = same "json" + cosine similarity matrix
|
|
||||||
std::string embd_sep = "\n"; // separator of embeddings
|
|
||||||
std::string cls_sep = "\t"; // separator of classification sequences
|
|
||||||
|
|
||||||
// server params
|
|
||||||
int32_t port = 8080; // server listens on this network port
|
|
||||||
int32_t timeout_read = 600; // http read timeout in seconds
|
|
||||||
int32_t timeout_write = timeout_read; // http write timeout in seconds
|
|
||||||
int32_t n_threads_http = -1; // number of threads to process HTTP requests (TODO: support threadpool)
|
|
||||||
int32_t n_cache_reuse = 0; // min chunk size to reuse from the cache via KV shifting
|
|
||||||
int32_t n_ctx_checkpoints = 8; // max number of context checkpoints per slot
|
|
||||||
int32_t cache_ram_mib = 8192; // -1 = no limit, 0 - disable, 1 = 1 MiB, etc.
|
|
||||||
|
|
||||||
std::string hostname = "127.0.0.1";
|
|
||||||
std::string public_path = ""; // NOLINT
|
|
||||||
std::string api_prefix = ""; // NOLINT
|
|
||||||
std::string chat_template = ""; // NOLINT
|
|
||||||
bool use_jinja = true; // NOLINT
|
|
||||||
bool enable_chat_template = true;
|
|
||||||
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
|
|
||||||
int reasoning_budget = -1;
|
|
||||||
bool prefill_assistant = true; // if true, any trailing assistant message will be prefilled into the response
|
|
||||||
int sleep_idle_seconds = -1; // if >0, server will sleep after this many seconds of idle time
|
|
||||||
|
|
||||||
std::vector<std::string> api_keys;
|
|
||||||
|
|
||||||
std::string ssl_file_key = ""; // NOLINT
|
|
||||||
std::string ssl_file_cert = ""; // NOLINT
|
|
||||||
|
|
||||||
std::map<std::string, std::string> default_template_kwargs;
|
|
||||||
|
|
||||||
// webui configs
|
|
||||||
bool webui = true;
|
|
||||||
std::string webui_config_json;
|
|
||||||
|
|
||||||
// "advanced" endpoints are disabled by default for better security
|
|
||||||
bool endpoint_slots = true;
|
|
||||||
bool endpoint_props = false; // only control POST requests, not GET
|
|
||||||
bool endpoint_metrics = false;
|
|
||||||
|
|
||||||
// router server configs
|
|
||||||
std::string models_dir = ""; // directory containing models for the router server
|
|
||||||
std::string models_preset = ""; // directory containing model presets for the router server
|
|
||||||
int models_max = 4; // maximum number of models to load simultaneously
|
|
||||||
bool models_autoload = true; // automatically load models when requested via the router server
|
|
||||||
|
|
||||||
bool log_json = false;
|
|
||||||
|
|
||||||
std::string slot_save_path;
|
|
||||||
std::string media_path; // path to directory for loading media files
|
|
||||||
|
|
||||||
float slot_prompt_similarity = 0.1f;
|
|
||||||
|
|
||||||
// batched-bench params
|
|
||||||
bool is_pp_shared = false;
|
|
||||||
bool is_tg_separate = false;
|
|
||||||
|
|
||||||
std::vector<int32_t> n_pp;
|
|
||||||
std::vector<int32_t> n_tg;
|
|
||||||
std::vector<int32_t> n_pl;
|
|
||||||
|
|
||||||
// retrieval params
|
|
||||||
std::vector<std::string> context_files; // context files to embed
|
|
||||||
|
|
||||||
int32_t chunk_size = 64; // chunk size for context embedding
|
|
||||||
|
|
||||||
std::string chunk_separator = "\n"; // chunk separator for context embedding
|
|
||||||
|
|
||||||
// passkey params
|
|
||||||
int32_t n_junk = 250; // number of times to repeat the junk text
|
|
||||||
int32_t i_pos = -1; // position of the passkey in the junk text
|
|
||||||
|
|
||||||
// imatrix params
|
|
||||||
int32_t n_out_freq = 10; // output the imatrix every n_out_freq iterations
|
|
||||||
int32_t n_save_freq = 0; // save the imatrix every n_save_freq iterations
|
|
||||||
int32_t i_chunk = 0; // start processing from this chunk
|
|
||||||
int8_t imat_dat = 0; // whether the legacy imatrix.dat format should be output (gguf <= 0 < dat)
|
|
||||||
|
|
||||||
bool process_output = false; // collect data for the output tensor
|
|
||||||
bool compute_ppl = true; // whether to compute perplexity
|
|
||||||
bool show_statistics = false; // show imatrix statistics per tensor
|
|
||||||
bool parse_special = false; // whether to parse special tokens during imatrix tokenization
|
|
||||||
|
|
||||||
// cvector-generator params
|
|
||||||
int n_pca_batch = 100;
|
|
||||||
int n_pca_iterations = 1000;
|
|
||||||
dimre_method cvector_dimre_method = DIMRE_METHOD_PCA;
|
|
||||||
std::string cvector_positive_file = "tools/cvector-generator/positive.txt";
|
|
||||||
std::string cvector_negative_file = "tools/cvector-generator/negative.txt";
|
|
||||||
|
|
||||||
bool spm_infill = false; // suffix/prefix/middle pattern for infill
|
|
||||||
|
|
||||||
// batched-bench params
|
|
||||||
bool batched_bench_output_jsonl = false;
|
|
||||||
|
|
||||||
// common params
|
|
||||||
std::string out_file; // output filename for all example programs
|
|
||||||
// optional callback for model loading progress and cancellation:
|
|
||||||
// called with a progress value between 0.0 and 1.0.
|
|
||||||
// return false from callback to abort model loading or true to continue
|
|
||||||
llama_progress_callback load_progress_callback = NULL;
|
|
||||||
void * load_progress_callback_user_data = NULL;
|
|
||||||
|
|
||||||
bool has_speculative() const {
|
|
||||||
return !speculative.model.path.empty() || !speculative.model.hf_repo.empty();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// call once at the start of a program if it uses libcommon
|
|
||||||
// initializes the logging system and prints info about the build
|
|
||||||
void common_init();
|
|
||||||
|
|
||||||
std::string common_params_get_system_info(const common_params & params);
|
|
||||||
|
|
||||||
bool parse_cpu_range(const std::string & range, bool(&boolmask)[GGML_MAX_N_THREADS]);
|
|
||||||
bool parse_cpu_mask(const std::string & mask, bool(&boolmask)[GGML_MAX_N_THREADS]);
|
|
||||||
void postprocess_cpu_params(cpu_params & cpuparams, const cpu_params * role_model = nullptr);
|
|
||||||
bool set_process_priority(enum ggml_sched_priority prio);
|
|
||||||
|
|
||||||
//
|
|
||||||
// String utils
|
|
||||||
//
|
|
||||||
|
|
||||||
#ifdef __GNUC__
|
|
||||||
# if defined(__MINGW32__) && !defined(__clang__)
|
|
||||||
# define LLAMA_COMMON_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
|
|
||||||
# else
|
|
||||||
# define LLAMA_COMMON_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
|
|
||||||
# endif
|
|
||||||
#else
|
|
||||||
# define LLAMA_COMMON_ATTRIBUTE_FORMAT(...)
|
|
||||||
#endif
|
|
||||||
|
|
||||||
LLAMA_COMMON_ATTRIBUTE_FORMAT(1, 2)
|
|
||||||
std::string string_format(const char * fmt, ...);
|
|
||||||
|
|
||||||
std::string string_strip(const std::string & str);
|
|
||||||
std::string string_get_sortable_timestamp();
|
|
||||||
|
|
||||||
std::string string_join(const std::vector<std::string> & values, const std::string & separator);
|
|
||||||
std::vector<std::string> string_split(const std::string & str, const std::string & delimiter);
|
|
||||||
std::string string_repeat(const std::string & str, size_t n);
|
|
||||||
|
|
||||||
void string_replace_all(std::string & s, const std::string & search, const std::string & replace);
|
|
||||||
|
|
||||||
std::string regex_escape(const std::string & s);
|
|
||||||
|
|
||||||
template<class T>
|
|
||||||
static std::vector<T> string_split(const std::string & str, char delim) {
|
|
||||||
static_assert(!std::is_same<T, std::string>::value, "Please use the specialized version for std::string");
|
|
||||||
std::vector<T> values;
|
|
||||||
std::istringstream str_stream(str);
|
|
||||||
std::string token;
|
|
||||||
while (std::getline(str_stream, token, delim)) {
|
|
||||||
T value;
|
|
||||||
std::istringstream token_stream(token);
|
|
||||||
token_stream >> value;
|
|
||||||
values.push_back(value);
|
|
||||||
}
|
|
||||||
return values;
|
|
||||||
}
|
|
||||||
|
|
||||||
template<>
|
|
||||||
std::vector<std::string> string_split<std::string>(const std::string & input, char separator)
|
|
||||||
{
|
|
||||||
std::vector<std::string> parts;
|
|
||||||
size_t begin_pos = 0;
|
|
||||||
size_t separator_pos = input.find(separator);
|
|
||||||
while (separator_pos != std::string::npos) {
|
|
||||||
std::string part = input.substr(begin_pos, separator_pos - begin_pos);
|
|
||||||
parts.emplace_back(part);
|
|
||||||
begin_pos = separator_pos + 1;
|
|
||||||
separator_pos = input.find(separator, begin_pos);
|
|
||||||
}
|
|
||||||
parts.emplace_back(input.substr(begin_pos, separator_pos - begin_pos));
|
|
||||||
return parts;
|
|
||||||
}
|
|
||||||
|
|
||||||
static bool string_starts_with(const std::string & str,
|
|
||||||
const std::string & prefix) { // While we wait for C++20's std::string::starts_with...
|
|
||||||
return str.rfind(prefix, 0) == 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
// While we wait for C++20's std::string::ends_with...
|
|
||||||
bool string_ends_with(const std::string_view & str, const std::string_view & suffix);
|
|
||||||
bool string_remove_suffix(std::string & str, const std::string_view & suffix);
|
|
||||||
size_t string_find_partial_stop(const std::string_view & str, const std::string_view & stop);
|
|
||||||
|
|
||||||
bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides);
|
|
||||||
void string_process_escapes(std::string & input);
|
|
||||||
|
|
||||||
std::string string_from(bool value);
|
|
||||||
std::string string_from(const std::vector<int> & values);
|
|
||||||
std::string string_from(const struct llama_context * ctx, const std::vector<llama_token> & tokens);
|
|
||||||
std::string string_from(const struct llama_context * ctx, const struct llama_batch & batch);
|
|
||||||
|
|
||||||
//
|
|
||||||
// Filesystem utils
|
|
||||||
//
|
|
||||||
|
|
||||||
bool fs_validate_filename(const std::string & filename, bool allow_subdirs = false);
|
|
||||||
bool fs_create_directory_with_parents(const std::string & path);
|
|
||||||
bool fs_is_directory(const std::string & path);
|
|
||||||
|
|
||||||
std::string fs_get_cache_directory();
|
|
||||||
std::string fs_get_cache_file(const std::string & filename);
|
|
||||||
|
|
||||||
struct common_file_info {
|
|
||||||
std::string path;
|
|
||||||
std::string name;
|
|
||||||
size_t size = 0; // in bytes
|
|
||||||
bool is_dir = false;
|
|
||||||
};
|
|
||||||
std::vector<common_file_info> fs_list(const std::string & path, bool include_directories);
|
|
||||||
|
|
||||||
//
|
|
||||||
// TTY utils
|
|
||||||
//
|
|
||||||
|
|
||||||
// Auto-detect if colors can be enabled based on terminal and environment
|
|
||||||
bool tty_can_use_colors();
|
|
||||||
|
|
||||||
//
|
|
||||||
// Model utils
|
|
||||||
//
|
|
||||||
|
|
||||||
struct common_sampler;
|
|
||||||
|
|
||||||
// note: defines the model, context, samplers, ets. lifetimes
|
|
||||||
struct common_init_result {
|
|
||||||
common_init_result(common_params & params);
|
|
||||||
~common_init_result();
|
|
||||||
|
|
||||||
llama_model * model();
|
|
||||||
llama_context * context();
|
|
||||||
|
|
||||||
common_sampler * sampler(llama_seq_id seq_id);
|
|
||||||
void reset_samplers();
|
|
||||||
|
|
||||||
std::vector<llama_adapter_lora_ptr> & lora();
|
|
||||||
|
|
||||||
void free_context();
|
|
||||||
|
|
||||||
private:
|
|
||||||
struct impl;
|
|
||||||
std::unique_ptr<impl> pimpl;
|
|
||||||
};
|
|
||||||
|
|
||||||
using common_init_result_ptr = std::unique_ptr<common_init_result>;
|
|
||||||
|
|
||||||
common_init_result_ptr common_init_from_params(common_params & params);
|
|
||||||
|
|
||||||
struct llama_model_params common_model_params_to_llama ( common_params & params);
|
|
||||||
struct llama_context_params common_context_params_to_llama(const common_params & params);
|
|
||||||
struct ggml_threadpool_params ggml_threadpool_params_from_cpu_params(const cpu_params & params);
|
|
||||||
|
|
||||||
// clear LoRA adapters from context, then apply new list of adapters
|
|
||||||
void common_set_adapter_lora(struct llama_context * ctx, std::vector<common_adapter_lora_info> & lora);
|
|
||||||
|
|
||||||
std::string get_model_endpoint();
|
|
||||||
|
|
||||||
//
|
|
||||||
// Batch utils
|
|
||||||
//
|
|
||||||
|
|
||||||
void common_batch_clear(struct llama_batch & batch);
|
|
||||||
|
|
||||||
void common_batch_add(
|
|
||||||
struct llama_batch & batch,
|
|
||||||
llama_token id,
|
|
||||||
llama_pos pos,
|
|
||||||
const std::vector<llama_seq_id> & seq_ids,
|
|
||||||
bool logits);
|
|
||||||
|
|
||||||
//
|
|
||||||
// Token utils
|
|
||||||
//
|
|
||||||
|
|
||||||
// longest common prefix
|
|
||||||
size_t common_lcp(const llama_tokens & a, const llama_tokens & b);
|
|
||||||
|
|
||||||
// longet common subsequence
|
|
||||||
size_t common_lcs(const llama_tokens & a, const llama_tokens & b);
|
|
||||||
|
|
||||||
//
|
|
||||||
// Vocab utils
|
|
||||||
//
|
|
||||||
|
|
||||||
// tokenizes a string into a vector of tokens
|
|
||||||
// should work similar to Python's `tokenizer.encode`
|
|
||||||
std::vector<llama_token> common_tokenize(
|
|
||||||
const struct llama_context * ctx,
|
|
||||||
const std::string & text,
|
|
||||||
bool add_special,
|
|
||||||
bool parse_special = false);
|
|
||||||
|
|
||||||
std::vector<llama_token> common_tokenize(
|
|
||||||
const struct llama_vocab * vocab,
|
|
||||||
const std::string & text,
|
|
||||||
bool add_special,
|
|
||||||
bool parse_special = false);
|
|
||||||
|
|
||||||
// tokenizes a token into a piece, optionally renders special/control tokens
|
|
||||||
// should work similar to Python's `tokenizer.id_to_piece`
|
|
||||||
std::string common_token_to_piece(
|
|
||||||
const struct llama_context * ctx,
|
|
||||||
llama_token token,
|
|
||||||
bool special = true);
|
|
||||||
|
|
||||||
std::string common_token_to_piece(
|
|
||||||
const struct llama_vocab * vocab,
|
|
||||||
llama_token token,
|
|
||||||
bool special = true);
|
|
||||||
|
|
||||||
// detokenizes a vector of tokens into a string
|
|
||||||
// should work similar to Python's `tokenizer.decode`
|
|
||||||
// optionally renders special/control tokens
|
|
||||||
std::string common_detokenize(
|
|
||||||
const struct llama_context * ctx,
|
|
||||||
const std::vector<llama_token> & tokens,
|
|
||||||
bool special = true);
|
|
||||||
|
|
||||||
std::string common_detokenize(
|
|
||||||
const struct llama_vocab * vocab,
|
|
||||||
const std::vector<llama_token> & tokens,
|
|
||||||
bool special = true);
|
|
||||||
|
|
||||||
//
|
|
||||||
// Embedding utils
|
|
||||||
//
|
|
||||||
|
|
||||||
// TODO: repace embd_norm with an enum
|
|
||||||
void common_embd_normalize(const float * inp, float * out, int n, int embd_norm);
|
|
||||||
|
|
||||||
float common_embd_similarity_cos(const float * embd1, const float * embd2, int n);
|
|
||||||
|
|
||||||
//
|
|
||||||
// Control vector utils
|
|
||||||
//
|
|
||||||
|
|
||||||
struct common_control_vector_data {
|
|
||||||
int n_embd;
|
|
||||||
|
|
||||||
// stores data for layers [1, n_layer] where n_layer = data.size() / n_embd
|
|
||||||
std::vector<float> data;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct common_control_vector_load_info {
|
|
||||||
float strength;
|
|
||||||
|
|
||||||
std::string fname;
|
|
||||||
};
|
|
||||||
|
|
||||||
// Load control vectors, scale each by strength, and add them together.
|
|
||||||
// On error, returns {-1, empty}
|
|
||||||
common_control_vector_data common_control_vector_load(const std::vector<common_control_vector_load_info> & load_infos);
|
|
||||||
|
|
||||||
//
|
|
||||||
// Split utils
|
|
||||||
//
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
|
|
||||||
const char * const LLM_KV_SPLIT_NO = "split.no";
|
|
||||||
const char * const LLM_KV_SPLIT_COUNT = "split.count";
|
|
||||||
const char * const LLM_KV_SPLIT_TENSORS_COUNT = "split.tensors.count";
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
//
|
|
||||||
// MoE utils
|
|
||||||
//
|
|
||||||
|
|
||||||
const char * const LLM_FFN_EXPS_REGEX = "\\.ffn_(up|down|gate)_(ch|)exps";
|
|
||||||
|
|
||||||
static std::string llm_ffn_exps_block_regex(int idx) {
|
|
||||||
return string_format("blk\\.%d%s", idx, LLM_FFN_EXPS_REGEX);
|
|
||||||
}
|
|
||||||
|
|
||||||
static llama_model_tensor_buft_override llm_ffn_exps_cpu_override() {
|
|
||||||
return { LLM_FFN_EXPS_REGEX, ggml_backend_cpu_buffer_type() };
|
|
||||||
}
|
|
||||||
|
|
||||||
//
|
|
||||||
// training utils
|
|
||||||
//
|
|
||||||
|
|
||||||
ggml_opt_dataset_t common_opt_dataset_init(struct llama_context * ctx, const std::vector<llama_token> & tokens, int64_t stride);
|
|
||||||
|
|
||||||
// "adamw" or "sgd" (case insensitive)
|
|
||||||
enum ggml_opt_optimizer_type common_opt_get_optimizer(const char *);
|
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -1,41 +0,0 @@
|
|||||||
// Console functions
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include "common.h"
|
|
||||||
|
|
||||||
#include <string>
|
|
||||||
|
|
||||||
enum display_type {
|
|
||||||
DISPLAY_TYPE_RESET = 0,
|
|
||||||
DISPLAY_TYPE_INFO,
|
|
||||||
DISPLAY_TYPE_PROMPT,
|
|
||||||
DISPLAY_TYPE_REASONING,
|
|
||||||
DISPLAY_TYPE_USER_INPUT,
|
|
||||||
DISPLAY_TYPE_ERROR
|
|
||||||
};
|
|
||||||
|
|
||||||
namespace console {
|
|
||||||
void init(bool use_simple_io, bool use_advanced_display);
|
|
||||||
void cleanup();
|
|
||||||
void set_display(display_type display);
|
|
||||||
bool readline(std::string & line, bool multiline_input);
|
|
||||||
|
|
||||||
namespace spinner {
|
|
||||||
void start();
|
|
||||||
void stop();
|
|
||||||
}
|
|
||||||
|
|
||||||
// note: the logging API below output directly to stdout
|
|
||||||
// it can negatively impact performance if used on inference thread
|
|
||||||
// only use in in a dedicated CLI thread
|
|
||||||
// for logging in inference thread, use log.h instead
|
|
||||||
|
|
||||||
LLAMA_COMMON_ATTRIBUTE_FORMAT(1, 2)
|
|
||||||
void log(const char * fmt, ...);
|
|
||||||
|
|
||||||
LLAMA_COMMON_ATTRIBUTE_FORMAT(1, 2)
|
|
||||||
void error(const char * fmt, ...);
|
|
||||||
|
|
||||||
void flush();
|
|
||||||
}
|
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -1,84 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
#include <string>
|
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
struct common_params_model;
|
|
||||||
|
|
||||||
using common_header = std::pair<std::string, std::string>;
|
|
||||||
using common_header_list = std::vector<common_header>;
|
|
||||||
|
|
||||||
struct common_remote_params {
|
|
||||||
common_header_list headers;
|
|
||||||
long timeout = 0; // in seconds, 0 means no timeout
|
|
||||||
long max_size = 0; // unlimited if 0
|
|
||||||
};
|
|
||||||
|
|
||||||
// get remote file content, returns <http_code, raw_response_body>
|
|
||||||
std::pair<long, std::vector<char>> common_remote_get_content(const std::string & url, const common_remote_params & params);
|
|
||||||
|
|
||||||
// split HF repo with tag into <repo, tag>
|
|
||||||
// for example: "user/model:tag" -> <"user/model", "tag">
|
|
||||||
// if tag is not present, default to "latest"
|
|
||||||
// example: "user/model" -> <"user/model", "latest">
|
|
||||||
std::pair<std::string, std::string> common_download_split_repo_tag(const std::string & hf_repo_with_tag);
|
|
||||||
|
|
||||||
struct common_cached_model_info {
|
|
||||||
std::string manifest_path;
|
|
||||||
std::string user;
|
|
||||||
std::string model;
|
|
||||||
std::string tag;
|
|
||||||
size_t size = 0; // GGUF size in bytes
|
|
||||||
// return string representation like "user/model:tag"
|
|
||||||
// if tag is "latest", it will be omitted
|
|
||||||
std::string to_string() const {
|
|
||||||
return user + "/" + model + (tag == "latest" ? "" : ":" + tag);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
struct common_hf_file_res {
|
|
||||||
std::string repo; // repo name with ":tag" removed
|
|
||||||
std::string ggufFile;
|
|
||||||
std::string mmprojFile;
|
|
||||||
};
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Allow getting the HF file from the HF repo with tag (like ollama), for example:
|
|
||||||
* - bartowski/Llama-3.2-3B-Instruct-GGUF:q4
|
|
||||||
* - bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M
|
|
||||||
* - bartowski/Llama-3.2-3B-Instruct-GGUF:q5_k_s
|
|
||||||
* Tag is optional, default to "latest" (meaning it checks for Q4_K_M first, then Q4, then if not found, return the first GGUF file in repo)
|
|
||||||
*
|
|
||||||
* Return pair of <repo, file> (with "repo" already having tag removed)
|
|
||||||
*
|
|
||||||
* Note: we use the Ollama-compatible HF API, but not using the blobId. Instead, we use the special "ggufFile" field which returns the value for "hf_file". This is done to be backward-compatible with existing cache files.
|
|
||||||
*/
|
|
||||||
common_hf_file_res common_get_hf_file(
|
|
||||||
const std::string & hf_repo_with_tag,
|
|
||||||
const std::string & bearer_token,
|
|
||||||
bool offline,
|
|
||||||
const common_header_list & headers = {}
|
|
||||||
);
|
|
||||||
|
|
||||||
// returns true if download succeeded
|
|
||||||
bool common_download_model(
|
|
||||||
const common_params_model & model,
|
|
||||||
const std::string & bearer_token,
|
|
||||||
bool offline,
|
|
||||||
const common_header_list & headers = {}
|
|
||||||
);
|
|
||||||
|
|
||||||
// returns list of cached models
|
|
||||||
std::vector<common_cached_model_info> common_list_cached_models();
|
|
||||||
|
|
||||||
// download single file from url to local path
|
|
||||||
// returns status code or -1 on error
|
|
||||||
int common_download_file_single(const std::string & url,
|
|
||||||
const std::string & path,
|
|
||||||
const std::string & bearer_token,
|
|
||||||
bool offline,
|
|
||||||
const common_header_list & headers = {});
|
|
||||||
|
|
||||||
// resolve and download model from Docker registry
|
|
||||||
// return local path to downloaded model file
|
|
||||||
std::string common_docker_resolve_model(const std::string & docker);
|
|
||||||
@@ -1,73 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
#include <cpp-httplib/httplib.h>
|
|
||||||
|
|
||||||
struct common_http_url {
|
|
||||||
std::string scheme;
|
|
||||||
std::string user;
|
|
||||||
std::string password;
|
|
||||||
std::string host;
|
|
||||||
std::string path;
|
|
||||||
};
|
|
||||||
|
|
||||||
static common_http_url common_http_parse_url(const std::string & url) {
|
|
||||||
common_http_url parts;
|
|
||||||
auto scheme_end = url.find("://");
|
|
||||||
|
|
||||||
if (scheme_end == std::string::npos) {
|
|
||||||
throw std::runtime_error("invalid URL: no scheme");
|
|
||||||
}
|
|
||||||
parts.scheme = url.substr(0, scheme_end);
|
|
||||||
|
|
||||||
if (parts.scheme != "http" && parts.scheme != "https") {
|
|
||||||
throw std::runtime_error("unsupported URL scheme: " + parts.scheme);
|
|
||||||
}
|
|
||||||
|
|
||||||
auto rest = url.substr(scheme_end + 3);
|
|
||||||
auto at_pos = rest.find('@');
|
|
||||||
|
|
||||||
if (at_pos != std::string::npos) {
|
|
||||||
auto auth = rest.substr(0, at_pos);
|
|
||||||
auto colon_pos = auth.find(':');
|
|
||||||
if (colon_pos != std::string::npos) {
|
|
||||||
parts.user = auth.substr(0, colon_pos);
|
|
||||||
parts.password = auth.substr(colon_pos + 1);
|
|
||||||
} else {
|
|
||||||
parts.user = auth;
|
|
||||||
}
|
|
||||||
rest = rest.substr(at_pos + 1);
|
|
||||||
}
|
|
||||||
|
|
||||||
auto slash_pos = rest.find('/');
|
|
||||||
|
|
||||||
if (slash_pos != std::string::npos) {
|
|
||||||
parts.host = rest.substr(0, slash_pos);
|
|
||||||
parts.path = rest.substr(slash_pos);
|
|
||||||
} else {
|
|
||||||
parts.host = rest;
|
|
||||||
parts.path = "/";
|
|
||||||
}
|
|
||||||
return parts;
|
|
||||||
}
|
|
||||||
|
|
||||||
static std::pair<httplib::Client, common_http_url> common_http_client(const std::string & url) {
|
|
||||||
common_http_url parts = common_http_parse_url(url);
|
|
||||||
|
|
||||||
if (parts.host.empty()) {
|
|
||||||
throw std::runtime_error("error: invalid URL format");
|
|
||||||
}
|
|
||||||
|
|
||||||
httplib::Client cli(parts.scheme + "://" + parts.host);
|
|
||||||
|
|
||||||
if (!parts.user.empty()) {
|
|
||||||
cli.set_basic_auth(parts.user, parts.password);
|
|
||||||
}
|
|
||||||
|
|
||||||
cli.set_follow_location(true);
|
|
||||||
|
|
||||||
return { std::move(cli), std::move(parts) };
|
|
||||||
}
|
|
||||||
|
|
||||||
static std::string common_http_show_masked_url(const common_http_url & parts) {
|
|
||||||
return parts.scheme + "://" + (parts.user.empty() ? "" : "****:****@") + parts.host + parts.path;
|
|
||||||
}
|
|
||||||
@@ -1,324 +0,0 @@
|
|||||||
#include "json-partial.h"
|
|
||||||
|
|
||||||
#include "log.h"
|
|
||||||
|
|
||||||
#include <nlohmann/json.hpp>
|
|
||||||
|
|
||||||
#include <string>
|
|
||||||
#include <regex>
|
|
||||||
|
|
||||||
using json = nlohmann::ordered_json;
|
|
||||||
|
|
||||||
enum common_json_stack_element_type {
|
|
||||||
COMMON_JSON_STACK_ELEMENT_OBJECT,
|
|
||||||
COMMON_JSON_STACK_ELEMENT_KEY,
|
|
||||||
COMMON_JSON_STACK_ELEMENT_ARRAY,
|
|
||||||
};
|
|
||||||
|
|
||||||
struct common_json_stack_element {
|
|
||||||
common_json_stack_element_type type;
|
|
||||||
std::string key;
|
|
||||||
};
|
|
||||||
|
|
||||||
bool common_json_parse(
|
|
||||||
const std::string & input,
|
|
||||||
const std::string & healing_marker,
|
|
||||||
common_json & out)
|
|
||||||
{
|
|
||||||
std::string::const_iterator it = input.begin();
|
|
||||||
const auto end = input.end();
|
|
||||||
return common_json_parse(it, end, healing_marker, out);
|
|
||||||
}
|
|
||||||
|
|
||||||
bool common_json_parse(
|
|
||||||
std::string::const_iterator & it,
|
|
||||||
const std::string::const_iterator & end,
|
|
||||||
const std::string & healing_marker,
|
|
||||||
common_json & out)
|
|
||||||
{
|
|
||||||
// // https://json.nlohmann.me/features/parsing/sax_interface/
|
|
||||||
struct json_error_locator : public nlohmann::json_sax<json> {
|
|
||||||
std::size_t position;
|
|
||||||
bool found_error;
|
|
||||||
std::string last_token;
|
|
||||||
std::string exception_message;
|
|
||||||
std::vector<common_json_stack_element> stack;
|
|
||||||
|
|
||||||
json_error_locator() : position(0), found_error(false) {}
|
|
||||||
|
|
||||||
bool parse_error(std::size_t position, const std::string & last_token, const json::exception & ex) override { // NOLINT
|
|
||||||
this->position = position - 1;
|
|
||||||
this->found_error = true;
|
|
||||||
this->last_token = last_token;
|
|
||||||
this->exception_message = ex.what();
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
void close_value() {
|
|
||||||
if (!stack.empty() && (stack.back().type == COMMON_JSON_STACK_ELEMENT_KEY)) {
|
|
||||||
stack.pop_back();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
bool null() override { // NOLINT
|
|
||||||
close_value();
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
bool boolean(bool) override { // NOLINT
|
|
||||||
close_value();
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
bool number_integer(number_integer_t) override { // NOLINT
|
|
||||||
close_value();
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
bool number_unsigned(number_unsigned_t) override { // NOLINT
|
|
||||||
close_value();
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
bool number_float(number_float_t, const string_t &) override { // NOLINT
|
|
||||||
close_value();
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
bool string(string_t &) override { // NOLINT
|
|
||||||
close_value();
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
bool binary(binary_t &) override { // NOLINT
|
|
||||||
close_value();
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
bool start_object(std::size_t) override { // NOLINT
|
|
||||||
stack.push_back({COMMON_JSON_STACK_ELEMENT_OBJECT, ""});
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
bool end_object() override {
|
|
||||||
GGML_ASSERT(!stack.empty() && stack.back().type == COMMON_JSON_STACK_ELEMENT_OBJECT);
|
|
||||||
stack.pop_back();
|
|
||||||
close_value();
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
bool key(string_t & key) override { // NOLINT
|
|
||||||
stack.push_back({COMMON_JSON_STACK_ELEMENT_KEY, key});
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
bool start_array(std::size_t) override { // NOLINT
|
|
||||||
stack.push_back({COMMON_JSON_STACK_ELEMENT_ARRAY, ""});
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
bool end_array() override {
|
|
||||||
GGML_ASSERT(!stack.empty() && stack.back().type == COMMON_JSON_STACK_ELEMENT_ARRAY);
|
|
||||||
stack.pop_back();
|
|
||||||
close_value();
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
json_error_locator err_loc;
|
|
||||||
auto start = it;
|
|
||||||
json::sax_parse(it, end, &err_loc);
|
|
||||||
|
|
||||||
if (err_loc.found_error) {
|
|
||||||
it = start;
|
|
||||||
auto temptative_end = it + err_loc.position;
|
|
||||||
// LOG_DBG("Error at position %zu (is_end = %s): %s\n", err_loc.position, temptative_end == end ? "true" : "false", err_loc.exception_message.c_str());
|
|
||||||
|
|
||||||
auto input = std::string(it, temptative_end);
|
|
||||||
try {
|
|
||||||
out.json = json::parse(input);
|
|
||||||
// out.json = json::parse(it, temptative_end);
|
|
||||||
it = temptative_end;
|
|
||||||
return true;
|
|
||||||
} catch (const std::exception & ex) {
|
|
||||||
// No, needs healing.
|
|
||||||
LOG_DBG("Failed to parse up to error: %s: <<<%s>>>\n", ex.what(), std::string(it, temptative_end).c_str());
|
|
||||||
}
|
|
||||||
auto can_parse = [](const std::string & str) {
|
|
||||||
try {
|
|
||||||
auto _ = json::parse(str); // NOLINT
|
|
||||||
return true;
|
|
||||||
} catch (const std::exception &) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
if (!healing_marker.empty() && !err_loc.stack.empty()) {
|
|
||||||
std::string str(it, temptative_end);
|
|
||||||
auto last_non_sp_pos = str.find_last_not_of(" \n\r\t");
|
|
||||||
if (last_non_sp_pos == std::string::npos) {
|
|
||||||
throw std::runtime_error("Cannot heal a truncated JSON that stopped in an unknown location");
|
|
||||||
}
|
|
||||||
auto last_non_sp_char = str[last_non_sp_pos];
|
|
||||||
// Used to detect stops on a number, which may not be complete.
|
|
||||||
auto was_maybe_number = [&]() {
|
|
||||||
if (!str.empty() && std::isspace(str.back())) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
return std::isdigit(last_non_sp_char) ||
|
|
||||||
last_non_sp_char == '.' ||
|
|
||||||
last_non_sp_char == 'e' ||
|
|
||||||
last_non_sp_char == 'E' ||
|
|
||||||
last_non_sp_char == '-';
|
|
||||||
};
|
|
||||||
|
|
||||||
std::string closing;
|
|
||||||
for (size_t i = err_loc.stack.size(); i > 0; i--) {
|
|
||||||
auto & el = err_loc.stack[i - 1];
|
|
||||||
if (el.type == COMMON_JSON_STACK_ELEMENT_OBJECT) {
|
|
||||||
closing += "}";
|
|
||||||
} else if (el.type == COMMON_JSON_STACK_ELEMENT_ARRAY) {
|
|
||||||
closing += "]";
|
|
||||||
} else if (el.type != COMMON_JSON_STACK_ELEMENT_KEY) {
|
|
||||||
throw std::runtime_error("Unexpected stack element type");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Matches a potentially partial unicode escape sequence, e.g. \u, \uX, \uXX, \uXXX, \uXXXX
|
|
||||||
static const std::regex partial_unicode_regex(R"(\\u(?:[0-9a-fA-F](?:[0-9a-fA-F](?:[0-9a-fA-F](?:[0-9a-fA-F])?)?)?)?$)");
|
|
||||||
|
|
||||||
auto is_high_surrogate = [&](const std::string & s) {
|
|
||||||
// Check if a partial of a high surrogate (U+D800-U+DBFF)
|
|
||||||
return s.length() >= 4 &&
|
|
||||||
s[0] == '\\' && s[1] == 'u' &&
|
|
||||||
std::tolower(s[2]) == 'd' &&
|
|
||||||
(s[3] == '8' || s[3] == '9' || std::tolower(s[3]) == 'a' || std::tolower(s[3]) == 'b');
|
|
||||||
};
|
|
||||||
|
|
||||||
// Initialize the unicode marker to a low surrogate to handle the edge case
|
|
||||||
// where a high surrogate (U+D800-U+DBFF) is immediately followed by a
|
|
||||||
// backslash (\)
|
|
||||||
std::string unicode_marker_padding = "udc00";
|
|
||||||
std::smatch last_unicode_seq;
|
|
||||||
|
|
||||||
if (std::regex_search(str, last_unicode_seq, partial_unicode_regex)) {
|
|
||||||
std::smatch second_last_seq;
|
|
||||||
std::string prelude = str.substr(0, last_unicode_seq.position());
|
|
||||||
|
|
||||||
// Pad the escape sequence with 0s until it forms a complete sequence of 6 characters
|
|
||||||
unicode_marker_padding = std::string(6 - last_unicode_seq.length(), '0');
|
|
||||||
|
|
||||||
if (is_high_surrogate(last_unicode_seq.str())) {
|
|
||||||
// If the sequence is a partial match for a high surrogate, add a low surrogate (U+DC00-U+UDFF)
|
|
||||||
unicode_marker_padding += "\\udc00";
|
|
||||||
} else if (std::regex_search(prelude, second_last_seq, partial_unicode_regex)) {
|
|
||||||
if (is_high_surrogate(second_last_seq.str())) {
|
|
||||||
// If this follows a high surrogate, pad it to be a low surrogate
|
|
||||||
if (last_unicode_seq.length() == 2) {
|
|
||||||
unicode_marker_padding = "dc00";
|
|
||||||
} else if (last_unicode_seq.length() == 3) {
|
|
||||||
unicode_marker_padding = "c00";
|
|
||||||
} else {
|
|
||||||
// The original unicode_marker_padding is already padded with 0s
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const auto & magic_seed = out.healing_marker.marker = healing_marker;//"$llama.cpp.json$";
|
|
||||||
|
|
||||||
if (err_loc.stack.back().type == COMMON_JSON_STACK_ELEMENT_KEY) {
|
|
||||||
// We're inside an object value
|
|
||||||
if (last_non_sp_char == ':' && can_parse(str + "1" + closing)) {
|
|
||||||
// Was about to create an object value
|
|
||||||
str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
|
|
||||||
} else if (can_parse(str + ": 1" + closing)) {
|
|
||||||
str += (out.healing_marker.json_dump_marker = ":\"" + magic_seed) + "\"" + closing;
|
|
||||||
} else if (last_non_sp_char == '{' && can_parse(str + closing)) {
|
|
||||||
// Was about to create an object
|
|
||||||
str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\": 1" + closing;
|
|
||||||
} else if (can_parse(str + "\"" + closing)) {
|
|
||||||
// Was inside an object value string
|
|
||||||
str += (out.healing_marker.json_dump_marker = magic_seed) + "\"" + closing;
|
|
||||||
} else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\"" + closing)) {
|
|
||||||
// Was inside an object value string after an escape
|
|
||||||
str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\"" + closing;
|
|
||||||
} else if (can_parse(str + unicode_marker_padding + "\"" + closing)) {
|
|
||||||
// Was inside an object value string after a partial unicode escape
|
|
||||||
str += (out.healing_marker.json_dump_marker = unicode_marker_padding + magic_seed) + "\"" + closing;
|
|
||||||
} else {
|
|
||||||
// find last :
|
|
||||||
auto last_pos = str.find_last_of(':');
|
|
||||||
if (last_pos == std::string::npos) {
|
|
||||||
throw std::runtime_error("Cannot heal a truncated JSON that stopped in an unknown location");
|
|
||||||
}
|
|
||||||
// Cutting back to opening : for object value
|
|
||||||
str = str.substr(0, last_pos + 1) + (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
|
|
||||||
}
|
|
||||||
} else if (err_loc.stack.back().type == COMMON_JSON_STACK_ELEMENT_ARRAY) {
|
|
||||||
if ((last_non_sp_char == ',' || last_non_sp_char == '[') && can_parse(str + "1" + closing)) {
|
|
||||||
// Was about to create an array value
|
|
||||||
str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
|
|
||||||
} else if (can_parse(str + "\"" + closing)) {
|
|
||||||
// Was inside an array value string
|
|
||||||
str += (out.healing_marker.json_dump_marker = magic_seed) + "\"" + closing;
|
|
||||||
} else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\"" + closing)) {
|
|
||||||
// Was inside an array value string after an escape
|
|
||||||
str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\"" + closing;
|
|
||||||
} else if (can_parse(str + unicode_marker_padding + "\"" + closing)) {
|
|
||||||
// Was inside an array value string after a partial unicode escape
|
|
||||||
str += (out.healing_marker.json_dump_marker = unicode_marker_padding + magic_seed) + "\"" + closing;
|
|
||||||
} else if (!was_maybe_number() && can_parse(str + ", 1" + closing)) {
|
|
||||||
// Had just finished a value
|
|
||||||
str += (out.healing_marker.json_dump_marker = ",\"" + magic_seed) + "\"" + closing;
|
|
||||||
} else {
|
|
||||||
auto last_pos = str.find_last_of("[,");
|
|
||||||
if (last_pos == std::string::npos) {
|
|
||||||
throw std::runtime_error("Cannot heal a truncated JSON array stopped in an unknown location");
|
|
||||||
}
|
|
||||||
// Cutting back to last [ or , for array value
|
|
||||||
str = str.substr(0, last_pos + 1) + (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
|
|
||||||
}
|
|
||||||
} else if (err_loc.stack.back().type == COMMON_JSON_STACK_ELEMENT_OBJECT) {
|
|
||||||
if ((last_non_sp_char == '{' && can_parse(str + closing)) ||
|
|
||||||
(last_non_sp_char == ',' && can_parse(str + "\"\": 1" + closing))) {
|
|
||||||
// Was about to create an object key+value
|
|
||||||
str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\": 1" + closing;
|
|
||||||
} else if (!was_maybe_number() && can_parse(str + ",\"\": 1" + closing)) {
|
|
||||||
// Was about to create an object key+value
|
|
||||||
str += (out.healing_marker.json_dump_marker = ",\"" + magic_seed) + "\": 1" + closing;
|
|
||||||
} else if (can_parse(str + "\": 1" + closing)) {
|
|
||||||
// Was inside an object key string
|
|
||||||
str += (out.healing_marker.json_dump_marker = magic_seed) + "\": 1" + closing;
|
|
||||||
} else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\": 1" + closing)) {
|
|
||||||
// Was inside an object key string after an escape
|
|
||||||
str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\": 1" + closing;
|
|
||||||
} else if (can_parse(str + unicode_marker_padding + "\": 1" + closing)) {
|
|
||||||
// Was inside an object key string after a partial unicode escape
|
|
||||||
str += (out.healing_marker.json_dump_marker = unicode_marker_padding + magic_seed) + "\": 1" + closing;
|
|
||||||
} else {
|
|
||||||
auto last_pos = str.find_last_of(':');
|
|
||||||
if (last_pos == std::string::npos) {
|
|
||||||
throw std::runtime_error("Cannot heal a truncated JSON object stopped in an unknown location");
|
|
||||||
}
|
|
||||||
// fprintf(stderr, "Cutting back to last : for object key+value\n");
|
|
||||||
str = str.substr(0, last_pos + 1) + (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
throw std::runtime_error("Cannot heal a truncated JSON object stopped in an unknown location");
|
|
||||||
}
|
|
||||||
// fprintf(stderr, "HEALED:\nSTRING <<<\n%s\n>>>\n\nmagic_cut: <<<\n%s\n>>>\n\n", str.c_str(), out.healing_marker.json_dump_marker.c_str());
|
|
||||||
out.json = json::parse(str);
|
|
||||||
it = temptative_end;
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
// handle unclosed top-level primitive
|
|
||||||
if (err_loc.position != 0 && !healing_marker.empty() && err_loc.stack.empty()) {
|
|
||||||
std::string str(it, temptative_end);
|
|
||||||
const auto & magic_seed = out.healing_marker.marker = healing_marker;
|
|
||||||
if (can_parse(str + "\"")) {
|
|
||||||
// Was inside an string
|
|
||||||
str += (out.healing_marker.json_dump_marker = magic_seed) + "\"";
|
|
||||||
} else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\"")) {
|
|
||||||
// Was inside an string after an escape
|
|
||||||
str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\"";
|
|
||||||
} else {
|
|
||||||
// TODO: handle more unclosed top-level primitive if the stack was empty but we got an error (e.g. "tru", "\"", etc...)
|
|
||||||
// fprintf(stderr, "Closing: TODO\n");
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
out.json = json::parse(str);
|
|
||||||
it = temptative_end;
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
out.json = json::parse(it, end);
|
|
||||||
it = end;
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
@@ -1,38 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
#include <nlohmann/json.hpp>
|
|
||||||
|
|
||||||
// Healing marker (empty if the JSON was fully parsed / wasn't healed).
|
|
||||||
struct common_healing_marker {
|
|
||||||
// Raw marker.
|
|
||||||
std::string marker;
|
|
||||||
|
|
||||||
// Cutting the `common_json.json.dump()` string at the (only) occurrence of this marker should yield the original partial JSON string (modulo spaces / if it had the same dump format).
|
|
||||||
std::string json_dump_marker;
|
|
||||||
};
|
|
||||||
|
|
||||||
// Represents a parsed JSON object, with its optional healing marker (a JSON dump fragment that can be used to find the position of healing in the JSON dump string)
|
|
||||||
struct common_json {
|
|
||||||
nlohmann::ordered_json json;
|
|
||||||
|
|
||||||
common_healing_marker healing_marker;
|
|
||||||
};
|
|
||||||
|
|
||||||
// Parse the JSON string, healing (closing) any partial JSON if `healing_marker` is not empty.
|
|
||||||
//
|
|
||||||
// Healing completes partial JSON strings by adding a (possibly modified) healing marker, then whatever is needed to close the JSON.
|
|
||||||
// This allows to parse the resulting healed JSON string, yet be able to cut it again if needed at the healing marker.
|
|
||||||
// (this is used when parsing JSON outputs from the models, then crafting partial JSONs for the partial tool calls in OAI format).
|
|
||||||
//
|
|
||||||
// For instance, parsing `{` with a healing marker `foo` will produce a healed JSON `{"foo":1}`, w/ json_dump_marker = `"foo"` (which can be used to break the JSON again).
|
|
||||||
bool common_json_parse(
|
|
||||||
const std::string & input,
|
|
||||||
const std::string & healing_marker,
|
|
||||||
common_json & out);
|
|
||||||
|
|
||||||
// Parse the JSON string (see overload above), but advancing an iterator to the end of the input when the (potentially partial) parsing succeeds.
|
|
||||||
bool common_json_parse(
|
|
||||||
std::string::const_iterator & it,
|
|
||||||
const std::string::const_iterator & end,
|
|
||||||
const std::string & healing_marker,
|
|
||||||
common_json & out);
|
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -1,43 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
#include <nlohmann/json_fwd.hpp>
|
|
||||||
|
|
||||||
#include <functional>
|
|
||||||
#include <memory>
|
|
||||||
#include <string>
|
|
||||||
|
|
||||||
std::string json_schema_to_grammar(const nlohmann::ordered_json & schema,
|
|
||||||
bool force_gbnf = false);
|
|
||||||
|
|
||||||
class common_schema_converter;
|
|
||||||
|
|
||||||
// Probes a JSON schema to extract information about its structure and type constraints.
|
|
||||||
class common_schema_info {
|
|
||||||
std::unique_ptr<common_schema_converter> impl_;
|
|
||||||
|
|
||||||
public:
|
|
||||||
common_schema_info();
|
|
||||||
~common_schema_info();
|
|
||||||
|
|
||||||
common_schema_info(const common_schema_info &) = delete;
|
|
||||||
common_schema_info & operator=(const common_schema_info &) = delete;
|
|
||||||
common_schema_info(common_schema_info &&) noexcept;
|
|
||||||
common_schema_info & operator=(common_schema_info &&) noexcept;
|
|
||||||
|
|
||||||
void resolve_refs(nlohmann::ordered_json & schema);
|
|
||||||
bool resolves_to_string(const nlohmann::ordered_json & schema);
|
|
||||||
};
|
|
||||||
|
|
||||||
struct common_grammar_builder {
|
|
||||||
std::function<std::string(const std::string &, const std::string &)> add_rule;
|
|
||||||
std::function<std::string(const std::string &, const nlohmann::ordered_json &)> add_schema;
|
|
||||||
std::function<void(nlohmann::ordered_json &)> resolve_refs;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct common_grammar_options {
|
|
||||||
bool dotall = false;
|
|
||||||
};
|
|
||||||
|
|
||||||
std::string gbnf_format_literal(const std::string & literal);
|
|
||||||
|
|
||||||
std::string build_grammar(const std::function<void(const common_grammar_builder &)> & cb, const common_grammar_options & options = {});
|
|
||||||
@@ -1,258 +0,0 @@
|
|||||||
#include "sampling.h"
|
|
||||||
#include "log.h"
|
|
||||||
|
|
||||||
#ifdef LLAMA_USE_LLGUIDANCE
|
|
||||||
|
|
||||||
# include "llguidance.h"
|
|
||||||
# include <cmath>
|
|
||||||
|
|
||||||
struct llama_sampler_llg {
|
|
||||||
const llama_vocab * vocab;
|
|
||||||
std::string grammar_kind;
|
|
||||||
std::string grammar_data;
|
|
||||||
LlgTokenizer * tokenizer;
|
|
||||||
LlgMatcher * grammar;
|
|
||||||
};
|
|
||||||
|
|
||||||
static LlgMatcher * llama_sampler_llg_new(LlgTokenizer * tokenizer, const char * grammar_kind,
|
|
||||||
const char * grammar_data) {
|
|
||||||
LlgConstraintInit cinit;
|
|
||||||
llg_constraint_init_set_defaults(&cinit, tokenizer);
|
|
||||||
const char * log_level = getenv("LLGUIDANCE_LOG_LEVEL");
|
|
||||||
if (log_level && *log_level) {
|
|
||||||
cinit.log_stderr_level = atoi(log_level);
|
|
||||||
}
|
|
||||||
auto c = llg_new_matcher(&cinit, grammar_kind, grammar_data);
|
|
||||||
if (llg_matcher_get_error(c)) {
|
|
||||||
LOG_ERR("llg error: %s\n", llg_matcher_get_error(c));
|
|
||||||
llg_free_matcher(c);
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
return c;
|
|
||||||
}
|
|
||||||
|
|
||||||
static const char * llama_sampler_llg_name(const llama_sampler * /*smpl*/) {
|
|
||||||
return "llguidance";
|
|
||||||
}
|
|
||||||
|
|
||||||
static void llama_sampler_llg_accept_impl(llama_sampler * smpl, llama_token token) {
|
|
||||||
auto * ctx = (llama_sampler_llg *) smpl->ctx;
|
|
||||||
if (ctx->grammar) {
|
|
||||||
llg_matcher_consume_token(ctx->grammar, token);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
static void llama_sampler_llg_apply(llama_sampler * smpl, llama_token_data_array * cur_p) {
|
|
||||||
auto * ctx = (llama_sampler_llg *) smpl->ctx;
|
|
||||||
if (ctx->grammar) {
|
|
||||||
const uint32_t * mask = llg_matcher_get_mask(ctx->grammar);
|
|
||||||
if (mask == nullptr) {
|
|
||||||
if (llg_matcher_compute_mask(ctx->grammar) == 0) {
|
|
||||||
mask = llg_matcher_get_mask(ctx->grammar);
|
|
||||||
} else {
|
|
||||||
LOG_ERR("llg error: %s\n", llg_matcher_get_error(ctx->grammar));
|
|
||||||
llg_free_matcher(ctx->grammar);
|
|
||||||
ctx->grammar = nullptr;
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for (size_t i = 0; i < cur_p->size; ++i) {
|
|
||||||
auto token = cur_p->data[i].id;
|
|
||||||
if ((mask[token / 32] & (1 << (token % 32))) == 0) {
|
|
||||||
cur_p->data[i].logit = -INFINITY;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
static void llama_sampler_llg_reset(llama_sampler * smpl) {
|
|
||||||
auto * ctx = (llama_sampler_llg *) smpl->ctx;
|
|
||||||
if (ctx->grammar) {
|
|
||||||
llg_matcher_reset(ctx->grammar);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
static llama_sampler * llama_sampler_llg_clone(const llama_sampler * smpl) {
|
|
||||||
const auto * ctx = (const llama_sampler_llg *) smpl->ctx;
|
|
||||||
|
|
||||||
auto * result = llama_sampler_init_llg(ctx->vocab, nullptr, nullptr);
|
|
||||||
|
|
||||||
// copy the state
|
|
||||||
{
|
|
||||||
auto * result_ctx = (llama_sampler_llg *) result->ctx;
|
|
||||||
|
|
||||||
if (ctx->grammar) {
|
|
||||||
result_ctx->grammar_kind = ctx->grammar_kind;
|
|
||||||
result_ctx->grammar_data = ctx->grammar_data;
|
|
||||||
result_ctx->grammar = llg_clone_matcher(ctx->grammar);
|
|
||||||
result_ctx->tokenizer = llg_clone_tokenizer(ctx->tokenizer);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
static void llama_sampler_llg_free(llama_sampler * smpl) {
|
|
||||||
const auto * ctx = (llama_sampler_llg *) smpl->ctx;
|
|
||||||
|
|
||||||
if (ctx->grammar) {
|
|
||||||
llg_free_matcher(ctx->grammar);
|
|
||||||
llg_free_tokenizer(ctx->tokenizer);
|
|
||||||
}
|
|
||||||
|
|
||||||
delete ctx;
|
|
||||||
}
|
|
||||||
|
|
||||||
static llama_sampler_i llama_sampler_llg_i = {
|
|
||||||
/* .name = */ llama_sampler_llg_name,
|
|
||||||
/* .accept = */ llama_sampler_llg_accept_impl,
|
|
||||||
/* .apply = */ llama_sampler_llg_apply,
|
|
||||||
/* .reset = */ llama_sampler_llg_reset,
|
|
||||||
/* .clone = */ llama_sampler_llg_clone,
|
|
||||||
/* .free = */ llama_sampler_llg_free,
|
|
||||||
/* .backend_init = */ NULL,
|
|
||||||
/* .backend_accept = */ NULL,
|
|
||||||
/* .backend_apply = */ NULL,
|
|
||||||
/* .backend_set_input = */ NULL,
|
|
||||||
};
|
|
||||||
|
|
||||||
static size_t llama_sampler_llg_tokenize_fn(const void * user_data, const uint8_t * bytes, size_t bytes_len,
|
|
||||||
uint32_t * output_tokens, size_t output_tokens_len) {
|
|
||||||
const llama_vocab * vocab = (const llama_vocab *) user_data;
|
|
||||||
int r = 0;
|
|
||||||
try {
|
|
||||||
r = llama_tokenize(vocab, (const char *) bytes, bytes_len, (int32_t *) output_tokens, output_tokens_len, false,
|
|
||||||
true);
|
|
||||||
} catch (const std::exception & e) {
|
|
||||||
GGML_ABORT("llama_tokenize failed: %s\n", e.what());
|
|
||||||
}
|
|
||||||
if (r < 0) {
|
|
||||||
return -r;
|
|
||||||
}
|
|
||||||
return r;
|
|
||||||
}
|
|
||||||
|
|
||||||
static LlgTokenizer * llama_sampler_llg_new_tokenizer(const llama_vocab * vocab) {
|
|
||||||
// TODO store the tokenizer in the vocab somehow
|
|
||||||
static const llama_vocab * vocab_cache;
|
|
||||||
static LlgTokenizer * tokenizer_cache;
|
|
||||||
|
|
||||||
if (vocab_cache == vocab) {
|
|
||||||
return llg_clone_tokenizer(tokenizer_cache);
|
|
||||||
}
|
|
||||||
|
|
||||||
auto tok_eos = llama_vocab_eot(vocab);
|
|
||||||
if (tok_eos == LLAMA_TOKEN_NULL) {
|
|
||||||
tok_eos = llama_vocab_eos(vocab);
|
|
||||||
}
|
|
||||||
|
|
||||||
size_t vocab_size = llama_vocab_n_tokens(vocab);
|
|
||||||
|
|
||||||
auto token_lens = new uint32_t[vocab_size];
|
|
||||||
// we typically have ~7 bytes per token; let's go on the safe side here
|
|
||||||
auto token_bytes_size = vocab_size * 16 + 1024 * 1024;
|
|
||||||
auto token_bytes = new uint8_t[token_bytes_size];
|
|
||||||
|
|
||||||
size_t offset = 0;
|
|
||||||
for (size_t i = 0; i < vocab_size; i++) {
|
|
||||||
size_t max_token = 1024;
|
|
||||||
if (token_bytes_size - offset < max_token) {
|
|
||||||
GGML_ABORT("token_bytes buffer too small\n");
|
|
||||||
}
|
|
||||||
|
|
||||||
llama_token token = i;
|
|
||||||
auto dp = (char *) token_bytes + offset;
|
|
||||||
auto size = llama_detokenize(vocab, &token, 1, dp, max_token, false, false);
|
|
||||||
if (size < 0) {
|
|
||||||
GGML_ABORT("llama_detokenize failed\n");
|
|
||||||
}
|
|
||||||
if (size == 0) {
|
|
||||||
size = llama_detokenize(vocab, &token, 1, dp + 1, max_token - 1, false, true);
|
|
||||||
if (size < 0) {
|
|
||||||
GGML_ABORT("llama_detokenize failed\n");
|
|
||||||
}
|
|
||||||
if (size != 0) {
|
|
||||||
*dp = '\xff'; // special token prefix marker
|
|
||||||
size += 1;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
token_lens[i] = size;
|
|
||||||
offset += size;
|
|
||||||
}
|
|
||||||
|
|
||||||
LlgTokenizerInit tinit = {
|
|
||||||
/* .vocab_size = */ (uint32_t) vocab_size,
|
|
||||||
/* .tok_eos = */ (uint32_t) tok_eos,
|
|
||||||
/* .token_lens = */ token_lens,
|
|
||||||
/* .token_bytes = */ token_bytes,
|
|
||||||
/* .tokenizer_json = */ nullptr,
|
|
||||||
/* .tokenize_assumes_string = */ true,
|
|
||||||
/* .tokenize_fn = */ llama_sampler_llg_tokenize_fn,
|
|
||||||
/* .use_approximate_greedy_tokenize_fn = */ false,
|
|
||||||
/* .tokenize_user_data = */ vocab,
|
|
||||||
/* .slices = */ nullptr,
|
|
||||||
};
|
|
||||||
|
|
||||||
char error_buffer[1024];
|
|
||||||
LlgTokenizer * tokenizer = llg_new_tokenizer(&tinit, error_buffer, sizeof(error_buffer));
|
|
||||||
|
|
||||||
delete[] token_bytes;
|
|
||||||
delete[] token_lens;
|
|
||||||
|
|
||||||
if (tokenizer == nullptr) {
|
|
||||||
LOG_ERR("llg tokenizer error: %s\n", error_buffer);
|
|
||||||
return tokenizer;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (tokenizer_cache) {
|
|
||||||
llg_free_tokenizer(tokenizer_cache);
|
|
||||||
}
|
|
||||||
vocab_cache = vocab;
|
|
||||||
tokenizer_cache = tokenizer;
|
|
||||||
|
|
||||||
return llg_clone_tokenizer(tokenizer_cache);
|
|
||||||
}
|
|
||||||
|
|
||||||
llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab, const char * grammar_kind,
|
|
||||||
const char * grammar_data) {
|
|
||||||
auto * ctx = new llama_sampler_llg;
|
|
||||||
|
|
||||||
if (grammar_kind != nullptr && grammar_kind[0] != '\0') {
|
|
||||||
auto tokenizer = llama_sampler_llg_new_tokenizer(vocab);
|
|
||||||
*ctx = {
|
|
||||||
/* .vocab = */ vocab,
|
|
||||||
/* .grammar_kind = */ grammar_kind,
|
|
||||||
/* .grammar_data = */ grammar_data,
|
|
||||||
/* .tokenizer = */ tokenizer,
|
|
||||||
/* .grammar = */ llama_sampler_llg_new(tokenizer, grammar_kind, grammar_data),
|
|
||||||
};
|
|
||||||
if (ctx->grammar) {
|
|
||||||
GGML_ASSERT(((size_t) llama_vocab_n_tokens(vocab) + 31) / 32 * 4 ==
|
|
||||||
llg_matcher_get_mask_byte_size(ctx->grammar));
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
*ctx = {
|
|
||||||
/* .vocab = */ vocab,
|
|
||||||
/* .grammar_kind = */ {},
|
|
||||||
/* .grammar_data = */ {},
|
|
||||||
/* .tokenizer = */ nullptr,
|
|
||||||
/* .grammar = */ nullptr,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
return llama_sampler_init(
|
|
||||||
/* .iface = */ &llama_sampler_llg_i,
|
|
||||||
/* .ctx = */ ctx);
|
|
||||||
}
|
|
||||||
|
|
||||||
#else
|
|
||||||
|
|
||||||
llama_sampler * llama_sampler_init_llg(const llama_vocab *, const char *, const char *) {
|
|
||||||
LOG_WRN("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled");
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
#endif // LLAMA_USE_LLGUIDANCE
|
|
||||||
@@ -1,446 +0,0 @@
|
|||||||
#include "common.h"
|
|
||||||
#include "log.h"
|
|
||||||
|
|
||||||
#include <chrono>
|
|
||||||
#include <condition_variable>
|
|
||||||
#include <cstdarg>
|
|
||||||
#include <cstdio>
|
|
||||||
#include <cstdlib>
|
|
||||||
#include <cstring>
|
|
||||||
#include <mutex>
|
|
||||||
#include <sstream>
|
|
||||||
#include <thread>
|
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
#if defined(_WIN32)
|
|
||||||
# include <io.h>
|
|
||||||
# include <windows.h>
|
|
||||||
# define isatty _isatty
|
|
||||||
# define fileno _fileno
|
|
||||||
#else
|
|
||||||
# include <unistd.h>
|
|
||||||
#endif // defined(_WIN32)
|
|
||||||
|
|
||||||
int common_log_verbosity_thold = LOG_DEFAULT_LLAMA;
|
|
||||||
|
|
||||||
void common_log_set_verbosity_thold(int verbosity) {
|
|
||||||
common_log_verbosity_thold = verbosity;
|
|
||||||
}
|
|
||||||
|
|
||||||
static int64_t t_us() {
|
|
||||||
return std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::system_clock::now().time_since_epoch()).count();
|
|
||||||
}
|
|
||||||
|
|
||||||
// colors
|
|
||||||
enum common_log_col : int {
|
|
||||||
COMMON_LOG_COL_DEFAULT = 0,
|
|
||||||
COMMON_LOG_COL_BOLD,
|
|
||||||
COMMON_LOG_COL_RED,
|
|
||||||
COMMON_LOG_COL_GREEN,
|
|
||||||
COMMON_LOG_COL_YELLOW,
|
|
||||||
COMMON_LOG_COL_BLUE,
|
|
||||||
COMMON_LOG_COL_MAGENTA,
|
|
||||||
COMMON_LOG_COL_CYAN,
|
|
||||||
COMMON_LOG_COL_WHITE,
|
|
||||||
};
|
|
||||||
|
|
||||||
// disable colors by default
|
|
||||||
static std::vector<const char *> g_col = {
|
|
||||||
"",
|
|
||||||
"",
|
|
||||||
"",
|
|
||||||
"",
|
|
||||||
"",
|
|
||||||
"",
|
|
||||||
"",
|
|
||||||
"",
|
|
||||||
"",
|
|
||||||
};
|
|
||||||
|
|
||||||
struct common_log_entry {
|
|
||||||
enum ggml_log_level level;
|
|
||||||
|
|
||||||
bool prefix;
|
|
||||||
|
|
||||||
int64_t timestamp;
|
|
||||||
|
|
||||||
std::vector<char> msg;
|
|
||||||
|
|
||||||
// signals the worker thread to stop
|
|
||||||
bool is_end;
|
|
||||||
|
|
||||||
void print(FILE * file = nullptr) const {
|
|
||||||
FILE * fcur = file;
|
|
||||||
if (!fcur) {
|
|
||||||
// stderr displays DBG messages only when their verbosity level is not higher than the threshold
|
|
||||||
// these messages will still be logged to a file
|
|
||||||
if (level == GGML_LOG_LEVEL_DEBUG && common_log_verbosity_thold < LOG_DEFAULT_DEBUG) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
fcur = stdout;
|
|
||||||
|
|
||||||
if (level != GGML_LOG_LEVEL_NONE) {
|
|
||||||
fcur = stderr;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (level != GGML_LOG_LEVEL_NONE && level != GGML_LOG_LEVEL_CONT && prefix) {
|
|
||||||
if (timestamp) {
|
|
||||||
// [M.s.ms.us]
|
|
||||||
fprintf(fcur, "%s%d.%02d.%03d.%03d%s ",
|
|
||||||
g_col[COMMON_LOG_COL_BLUE],
|
|
||||||
(int) (timestamp / 1000000 / 60),
|
|
||||||
(int) (timestamp / 1000000 % 60),
|
|
||||||
(int) (timestamp / 1000 % 1000),
|
|
||||||
(int) (timestamp % 1000),
|
|
||||||
g_col[COMMON_LOG_COL_DEFAULT]);
|
|
||||||
}
|
|
||||||
|
|
||||||
switch (level) {
|
|
||||||
case GGML_LOG_LEVEL_INFO: fprintf(fcur, "%sI %s", g_col[COMMON_LOG_COL_GREEN], g_col[COMMON_LOG_COL_DEFAULT]); break;
|
|
||||||
case GGML_LOG_LEVEL_WARN: fprintf(fcur, "%sW %s", g_col[COMMON_LOG_COL_MAGENTA], "" ); break;
|
|
||||||
case GGML_LOG_LEVEL_ERROR: fprintf(fcur, "%sE %s", g_col[COMMON_LOG_COL_RED], "" ); break;
|
|
||||||
case GGML_LOG_LEVEL_DEBUG: fprintf(fcur, "%sD %s", g_col[COMMON_LOG_COL_YELLOW], "" ); break;
|
|
||||||
default:
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fprintf(fcur, "%s", msg.data());
|
|
||||||
|
|
||||||
if (level == GGML_LOG_LEVEL_WARN || level == GGML_LOG_LEVEL_ERROR || level == GGML_LOG_LEVEL_DEBUG) {
|
|
||||||
fprintf(fcur, "%s", g_col[COMMON_LOG_COL_DEFAULT]);
|
|
||||||
}
|
|
||||||
|
|
||||||
fflush(fcur);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
struct common_log {
|
|
||||||
// default capacity - will be expanded if needed
|
|
||||||
common_log() : common_log(256) {}
|
|
||||||
|
|
||||||
common_log(size_t capacity) {
|
|
||||||
file = nullptr;
|
|
||||||
prefix = false;
|
|
||||||
timestamps = false;
|
|
||||||
running = false;
|
|
||||||
t_start = t_us();
|
|
||||||
|
|
||||||
// initial message size - will be expanded if longer messages arrive
|
|
||||||
entries.resize(capacity);
|
|
||||||
for (auto & entry : entries) {
|
|
||||||
entry.msg.resize(256);
|
|
||||||
}
|
|
||||||
|
|
||||||
head = 0;
|
|
||||||
tail = 0;
|
|
||||||
|
|
||||||
resume();
|
|
||||||
}
|
|
||||||
|
|
||||||
~common_log() {
|
|
||||||
pause();
|
|
||||||
if (file) {
|
|
||||||
fclose(file);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
std::mutex mtx;
|
|
||||||
std::thread thrd;
|
|
||||||
std::condition_variable cv;
|
|
||||||
|
|
||||||
FILE * file;
|
|
||||||
|
|
||||||
bool prefix;
|
|
||||||
bool timestamps;
|
|
||||||
bool running;
|
|
||||||
|
|
||||||
int64_t t_start;
|
|
||||||
|
|
||||||
// ring buffer of entries
|
|
||||||
std::vector<common_log_entry> entries;
|
|
||||||
size_t head;
|
|
||||||
size_t tail;
|
|
||||||
|
|
||||||
// worker thread copies into this
|
|
||||||
common_log_entry cur;
|
|
||||||
|
|
||||||
public:
|
|
||||||
void add(enum ggml_log_level level, const char * fmt, va_list args) {
|
|
||||||
std::lock_guard<std::mutex> lock(mtx);
|
|
||||||
|
|
||||||
if (!running) {
|
|
||||||
// discard messages while the worker thread is paused
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto & entry = entries[tail];
|
|
||||||
|
|
||||||
{
|
|
||||||
// cannot use args twice, so make a copy in case we need to expand the buffer
|
|
||||||
va_list args_copy;
|
|
||||||
va_copy(args_copy, args);
|
|
||||||
|
|
||||||
#if 1
|
|
||||||
const size_t n = vsnprintf(entry.msg.data(), entry.msg.size(), fmt, args);
|
|
||||||
if (n >= entry.msg.size()) {
|
|
||||||
entry.msg.resize(n + 1);
|
|
||||||
vsnprintf(entry.msg.data(), entry.msg.size(), fmt, args_copy);
|
|
||||||
}
|
|
||||||
#else
|
|
||||||
// hack for bolding arguments
|
|
||||||
|
|
||||||
std::stringstream ss;
|
|
||||||
for (int i = 0; fmt[i] != 0; i++) {
|
|
||||||
if (fmt[i] == '%') {
|
|
||||||
ss << LOG_COL_BOLD;
|
|
||||||
while (fmt[i] != ' ' && fmt[i] != ')' && fmt[i] != ']' && fmt[i] != 0) ss << fmt[i++];
|
|
||||||
ss << LOG_COL_DEFAULT;
|
|
||||||
if (fmt[i] == 0) break;
|
|
||||||
}
|
|
||||||
ss << fmt[i];
|
|
||||||
}
|
|
||||||
const size_t n = vsnprintf(entry.msg.data(), entry.msg.size(), ss.str().c_str(), args);
|
|
||||||
if (n >= entry.msg.size()) {
|
|
||||||
entry.msg.resize(n + 1);
|
|
||||||
vsnprintf(entry.msg.data(), entry.msg.size(), ss.str().c_str(), args_copy);
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
va_end(args_copy);
|
|
||||||
}
|
|
||||||
|
|
||||||
entry.level = level;
|
|
||||||
entry.prefix = prefix;
|
|
||||||
entry.timestamp = 0;
|
|
||||||
if (timestamps) {
|
|
||||||
entry.timestamp = t_us() - t_start;
|
|
||||||
}
|
|
||||||
entry.is_end = false;
|
|
||||||
|
|
||||||
tail = (tail + 1) % entries.size();
|
|
||||||
if (tail == head) {
|
|
||||||
// expand the buffer
|
|
||||||
std::vector<common_log_entry> new_entries(2*entries.size());
|
|
||||||
|
|
||||||
size_t new_tail = 0;
|
|
||||||
|
|
||||||
do {
|
|
||||||
new_entries[new_tail] = std::move(entries[head]);
|
|
||||||
|
|
||||||
head = (head + 1) % entries.size();
|
|
||||||
new_tail = (new_tail + 1);
|
|
||||||
} while (head != tail);
|
|
||||||
|
|
||||||
head = 0;
|
|
||||||
tail = new_tail;
|
|
||||||
|
|
||||||
for (size_t i = tail; i < new_entries.size(); i++) {
|
|
||||||
new_entries[i].msg.resize(256);
|
|
||||||
}
|
|
||||||
|
|
||||||
entries = std::move(new_entries);
|
|
||||||
}
|
|
||||||
|
|
||||||
cv.notify_one();
|
|
||||||
}
|
|
||||||
|
|
||||||
void resume() {
|
|
||||||
std::lock_guard<std::mutex> lock(mtx);
|
|
||||||
|
|
||||||
if (running) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
running = true;
|
|
||||||
|
|
||||||
thrd = std::thread([this]() {
|
|
||||||
while (true) {
|
|
||||||
{
|
|
||||||
std::unique_lock<std::mutex> lock(mtx);
|
|
||||||
cv.wait(lock, [this]() { return head != tail; });
|
|
||||||
|
|
||||||
cur = entries[head];
|
|
||||||
|
|
||||||
head = (head + 1) % entries.size();
|
|
||||||
}
|
|
||||||
|
|
||||||
if (cur.is_end) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
cur.print(); // stdout and stderr
|
|
||||||
|
|
||||||
if (file) {
|
|
||||||
cur.print(file);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
void pause() {
|
|
||||||
{
|
|
||||||
std::lock_guard<std::mutex> lock(mtx);
|
|
||||||
|
|
||||||
if (!running) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
running = false;
|
|
||||||
|
|
||||||
// push an entry to signal the worker thread to stop
|
|
||||||
{
|
|
||||||
auto & entry = entries[tail];
|
|
||||||
entry.is_end = true;
|
|
||||||
|
|
||||||
tail = (tail + 1) % entries.size();
|
|
||||||
}
|
|
||||||
|
|
||||||
cv.notify_one();
|
|
||||||
}
|
|
||||||
|
|
||||||
thrd.join();
|
|
||||||
}
|
|
||||||
|
|
||||||
void set_file(const char * path) {
|
|
||||||
pause();
|
|
||||||
|
|
||||||
if (file) {
|
|
||||||
fclose(file);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (path) {
|
|
||||||
file = fopen(path, "w");
|
|
||||||
} else {
|
|
||||||
file = nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
resume();
|
|
||||||
}
|
|
||||||
|
|
||||||
void set_colors(bool colors) {
|
|
||||||
pause();
|
|
||||||
|
|
||||||
if (colors) {
|
|
||||||
g_col[COMMON_LOG_COL_DEFAULT] = LOG_COL_DEFAULT;
|
|
||||||
g_col[COMMON_LOG_COL_BOLD] = LOG_COL_BOLD;
|
|
||||||
g_col[COMMON_LOG_COL_RED] = LOG_COL_RED;
|
|
||||||
g_col[COMMON_LOG_COL_GREEN] = LOG_COL_GREEN;
|
|
||||||
g_col[COMMON_LOG_COL_YELLOW] = LOG_COL_YELLOW;
|
|
||||||
g_col[COMMON_LOG_COL_BLUE] = LOG_COL_BLUE;
|
|
||||||
g_col[COMMON_LOG_COL_MAGENTA] = LOG_COL_MAGENTA;
|
|
||||||
g_col[COMMON_LOG_COL_CYAN] = LOG_COL_CYAN;
|
|
||||||
g_col[COMMON_LOG_COL_WHITE] = LOG_COL_WHITE;
|
|
||||||
} else {
|
|
||||||
for (size_t i = 0; i < g_col.size(); i++) {
|
|
||||||
g_col[i] = "";
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
resume();
|
|
||||||
}
|
|
||||||
|
|
||||||
void set_prefix(bool prefix) {
|
|
||||||
std::lock_guard<std::mutex> lock(mtx);
|
|
||||||
|
|
||||||
this->prefix = prefix;
|
|
||||||
}
|
|
||||||
|
|
||||||
void set_timestamps(bool timestamps) {
|
|
||||||
std::lock_guard<std::mutex> lock(mtx);
|
|
||||||
|
|
||||||
this->timestamps = timestamps;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
//
|
|
||||||
// public API
|
|
||||||
//
|
|
||||||
|
|
||||||
struct common_log * common_log_init() {
|
|
||||||
return new common_log;
|
|
||||||
}
|
|
||||||
|
|
||||||
struct common_log * common_log_main() {
|
|
||||||
static struct common_log log;
|
|
||||||
static std::once_flag init_flag;
|
|
||||||
std::call_once(init_flag, [&]() {
|
|
||||||
// Set default to auto-detect colors
|
|
||||||
log.set_colors(tty_can_use_colors());
|
|
||||||
});
|
|
||||||
|
|
||||||
return &log;
|
|
||||||
}
|
|
||||||
|
|
||||||
// C-style wrapper: pause the logger's worker thread (drains the queue first).
void common_log_pause(struct common_log * log) {
    log->pause();
}
|
|
||||||
|
|
||||||
// C-style wrapper: restart the logger's worker thread.
void common_log_resume(struct common_log * log) {
    log->resume();
}
|
|
||||||
|
|
||||||
// Destroy a logger created with common_log_init().
void common_log_free(struct common_log * log) {
    delete log;
}
|
|
||||||
|
|
||||||
// Queue a printf-style message at the given level; formatting happens inside log->add().
void common_log_add(struct common_log * log, enum ggml_log_level level, const char * fmt, ...) {
    va_list args;
    va_start(args, fmt);
    log->add(level, fmt, args);
    va_end(args);
}
|
|
||||||
|
|
||||||
// C-style wrapper: set (or clear, with NULL) the logger's file sink.
void common_log_set_file(struct common_log * log, const char * file) {
    log->set_file(file);
}
|
|
||||||
|
|
||||||
void common_log_set_colors(struct common_log * log, log_colors colors) {
|
|
||||||
if (colors == LOG_COLORS_AUTO) {
|
|
||||||
log->set_colors(tty_can_use_colors());
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (colors == LOG_COLORS_DISABLED) {
|
|
||||||
log->set_colors(false);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
GGML_ASSERT(colors == LOG_COLORS_ENABLED);
|
|
||||||
log->set_colors(true);
|
|
||||||
}
|
|
||||||
|
|
||||||
// C-style wrapper: toggle the per-line prefix.
void common_log_set_prefix(struct common_log * log, bool prefix) {
    log->set_prefix(prefix);
}
|
|
||||||
|
|
||||||
// C-style wrapper: toggle timestamps in the prefix.
void common_log_set_timestamps(struct common_log * log, bool timestamps) {
    log->set_timestamps(timestamps);
}
|
|
||||||
|
|
||||||
// Flush all pending messages: pause() drains the queue and joins the worker,
// resume() brings it back up.
void common_log_flush(struct common_log * log) {
    log->pause();
    log->resume();
}
|
|
||||||
|
|
||||||
static int common_get_verbosity(enum ggml_log_level level) {
|
|
||||||
switch (level) {
|
|
||||||
case GGML_LOG_LEVEL_DEBUG: return LOG_LEVEL_DEBUG;
|
|
||||||
case GGML_LOG_LEVEL_INFO: return LOG_LEVEL_INFO;
|
|
||||||
case GGML_LOG_LEVEL_WARN: return LOG_LEVEL_WARN;
|
|
||||||
case GGML_LOG_LEVEL_ERROR: return LOG_LEVEL_ERROR;
|
|
||||||
case GGML_LOG_LEVEL_CONT: return LOG_LEVEL_INFO; // same as INFO
|
|
||||||
case GGML_LOG_LEVEL_NONE:
|
|
||||||
default:
|
|
||||||
return LOG_LEVEL_OUTPUT;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void common_log_default_callback(enum ggml_log_level level, const char * text, void * /*user_data*/) {
|
|
||||||
auto verbosity = common_get_verbosity(level);
|
|
||||||
if (verbosity <= common_log_verbosity_thold) {
|
|
||||||
common_log_add(common_log_main(), level, "%s", text);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,119 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
#include "ggml.h" // for ggml_log_level
|
|
||||||
|
|
||||||
#define LOG_CLR_TO_EOL "\033[K\r"
|
|
||||||
#define LOG_COL_DEFAULT "\033[0m"
|
|
||||||
#define LOG_COL_BOLD "\033[1m"
|
|
||||||
#define LOG_COL_RED "\033[31m"
|
|
||||||
#define LOG_COL_GREEN "\033[32m"
|
|
||||||
#define LOG_COL_YELLOW "\033[33m"
|
|
||||||
#define LOG_COL_BLUE "\033[34m"
|
|
||||||
#define LOG_COL_MAGENTA "\033[35m"
|
|
||||||
#define LOG_COL_CYAN "\033[36m"
|
|
||||||
#define LOG_COL_WHITE "\033[37m"
|
|
||||||
|
|
||||||
#ifndef __GNUC__
|
|
||||||
# define LOG_ATTRIBUTE_FORMAT(...)
|
|
||||||
#elif defined(__MINGW32__) && !defined(__clang__)
|
|
||||||
# define LOG_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
|
|
||||||
#else
|
|
||||||
# define LOG_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#define LOG_LEVEL_DEBUG 4
|
|
||||||
#define LOG_LEVEL_INFO 3
|
|
||||||
#define LOG_LEVEL_WARN 2
|
|
||||||
#define LOG_LEVEL_ERROR 1
|
|
||||||
#define LOG_LEVEL_OUTPUT 0 // output data from tools
|
|
||||||
|
|
||||||
#define LOG_DEFAULT_DEBUG LOG_LEVEL_DEBUG
|
|
||||||
#define LOG_DEFAULT_LLAMA LOG_LEVEL_INFO
|
|
||||||
|
|
||||||
// Tri-state color selection for common_log_set_colors().
enum log_colors {
    LOG_COLORS_AUTO     = -1, // detect from the terminal at runtime
    LOG_COLORS_DISABLED =  0,
    LOG_COLORS_ENABLED  =  1,
};
|
|
||||||
|
|
||||||
// needed by the LOG_TMPL macro to avoid computing log arguments if the verbosity lower
|
|
||||||
// set via common_log_set_verbosity()
|
|
||||||
extern int common_log_verbosity_thold;
|
|
||||||
|
|
||||||
void common_log_set_verbosity_thold(int verbosity); // not thread-safe
|
|
||||||
|
|
||||||
void common_log_default_callback(enum ggml_log_level level, const char * text, void * user_data);
|
|
||||||
|
|
||||||
// the common_log uses an internal worker thread to print/write log messages
|
|
||||||
// when the worker thread is paused, incoming log messages are discarded
|
|
||||||
struct common_log;
|
|
||||||
|
|
||||||
struct common_log * common_log_init();
|
|
||||||
struct common_log * common_log_main(); // singleton, automatically destroys itself on exit
|
|
||||||
void common_log_pause (struct common_log * log); // pause the worker thread, not thread-safe
|
|
||||||
void common_log_resume(struct common_log * log); // resume the worker thread, not thread-safe
|
|
||||||
void common_log_free (struct common_log * log);
|
|
||||||
|
|
||||||
LOG_ATTRIBUTE_FORMAT(3, 4)
|
|
||||||
void common_log_add(struct common_log * log, enum ggml_log_level level, const char * fmt, ...);
|
|
||||||
|
|
||||||
// defaults: file = NULL, colors = false, prefix = false, timestamps = false
|
|
||||||
//
|
|
||||||
// regular log output:
|
|
||||||
//
|
|
||||||
// ggml_backend_metal_log_allocated_size: allocated buffer, size = 6695.84 MiB, ( 6695.91 / 21845.34)
|
|
||||||
// llm_load_tensors: ggml ctx size = 0.27 MiB
|
|
||||||
// llm_load_tensors: offloading 32 repeating layers to GPU
|
|
||||||
// llm_load_tensors: offloading non-repeating layers to GPU
|
|
||||||
//
|
|
||||||
// with prefix = true, timestamps = true, the log output will look like this:
|
|
||||||
//
|
|
||||||
// 0.00.035.060 D ggml_backend_metal_log_allocated_size: allocated buffer, size = 6695.84 MiB, ( 6695.91 / 21845.34)
|
|
||||||
// 0.00.035.064 I llm_load_tensors: ggml ctx size = 0.27 MiB
|
|
||||||
// 0.00.090.578 I llm_load_tensors: offloading 32 repeating layers to GPU
|
|
||||||
// 0.00.090.579 I llm_load_tensors: offloading non-repeating layers to GPU
|
|
||||||
//
|
|
||||||
// D - debug (stderr, V = LOG_DEFAULT_DEBUG)
|
|
||||||
// I - info (stdout, V = LOG_DEFAULT_INFO)
|
|
||||||
// W - warning (stderr, V = LOG_DEFAULT_WARN)
|
|
||||||
// E - error (stderr, V = LOG_DEFAULT_ERROR)
|
|
||||||
// O - output (stdout, V = LOG_DEFAULT_OUTPUT)
|
|
||||||
//
|
|
||||||
|
|
||||||
void common_log_set_file (struct common_log * log, const char * file); // not thread-safe
|
|
||||||
void common_log_set_colors (struct common_log * log, log_colors colors); // not thread-safe
|
|
||||||
void common_log_set_prefix (struct common_log * log, bool prefix); // whether to output prefix to each log
|
|
||||||
void common_log_set_timestamps(struct common_log * log, bool timestamps); // whether to output timestamps in the prefix
|
|
||||||
void common_log_flush (struct common_log * log); // flush all pending log messages
|
|
||||||
|
|
||||||
// helper macros for logging
|
|
||||||
// use these to avoid computing log arguments if the verbosity of the log is higher than the threshold
|
|
||||||
//
|
|
||||||
// for example:
|
|
||||||
//
|
|
||||||
// LOG_DBG("this is a debug message: %d\n", expensive_function());
|
|
||||||
//
|
|
||||||
// this will avoid calling expensive_function() if LOG_DEFAULT_DEBUG > common_log_verbosity_thold
|
|
||||||
//
|
|
||||||
|
|
||||||
#define LOG_TMPL(level, verbosity, ...) \
|
|
||||||
do { \
|
|
||||||
if ((verbosity) <= common_log_verbosity_thold) { \
|
|
||||||
common_log_add(common_log_main(), (level), __VA_ARGS__); \
|
|
||||||
} \
|
|
||||||
} while (0)
|
|
||||||
|
|
||||||
#define LOG(...) LOG_TMPL(GGML_LOG_LEVEL_NONE, LOG_LEVEL_OUTPUT, __VA_ARGS__)
|
|
||||||
#define LOGV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_NONE, verbosity, __VA_ARGS__)
|
|
||||||
|
|
||||||
#define LOG_DBG(...) LOG_TMPL(GGML_LOG_LEVEL_DEBUG, LOG_LEVEL_DEBUG, __VA_ARGS__)
|
|
||||||
#define LOG_INF(...) LOG_TMPL(GGML_LOG_LEVEL_INFO, LOG_LEVEL_INFO, __VA_ARGS__)
|
|
||||||
#define LOG_WRN(...) LOG_TMPL(GGML_LOG_LEVEL_WARN, LOG_LEVEL_WARN, __VA_ARGS__)
|
|
||||||
#define LOG_ERR(...) LOG_TMPL(GGML_LOG_LEVEL_ERROR, LOG_LEVEL_ERROR, __VA_ARGS__)
|
|
||||||
#define LOG_CNT(...) LOG_TMPL(GGML_LOG_LEVEL_CONT, LOG_LEVEL_INFO, __VA_ARGS__) // same as INFO
|
|
||||||
|
|
||||||
#define LOG_INFV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_INFO, verbosity, __VA_ARGS__)
|
|
||||||
#define LOG_WRNV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_WARN, verbosity, __VA_ARGS__)
|
|
||||||
#define LOG_ERRV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_ERROR, verbosity, __VA_ARGS__)
|
|
||||||
#define LOG_DBGV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_DEBUG, verbosity, __VA_ARGS__)
|
|
||||||
#define LOG_CNTV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_CONT, verbosity, __VA_ARGS__)
|
|
||||||
@@ -1,286 +0,0 @@
|
|||||||
#include "ngram-cache.h"
|
|
||||||
#include "common.h"
|
|
||||||
#include "log.h"
|
|
||||||
|
|
||||||
#include <cinttypes>
|
|
||||||
#include <cstdint>
|
|
||||||
#include <cstdio>
|
|
||||||
#include <fstream>
|
|
||||||
#include <thread>
|
|
||||||
#include <algorithm>
|
|
||||||
|
|
||||||
// Incrementally update the cache with the last `nnew` tokens of `inp`:
// for each ngram size in [ngram_min, ngram_max], count which token follows
// each ngram. Optionally prints an ETA to stderr every 10M processed ngrams.
void common_ngram_cache_update(common_ngram_cache & ngram_cache, int ngram_min, int ngram_max,
        std::vector<llama_token> & inp, int nnew, bool print_progress) {
    const int64_t t_start_ms = ggml_time_ms();
    const int64_t inp_size   = inp.size();

    const int64_t n_todo = inp_size * (ngram_max - ngram_min + 1);
    int64_t       n_done = 0;

    for (int64_t ngram_size = ngram_min; ngram_size <= ngram_max; ++ngram_size) {
        // only the last nnew tokens can start previously unseen (ngram, token) pairs
        const int64_t i_start = std::max(inp_size - nnew, ngram_size);
        for (int64_t i = i_start; i < inp_size; ++i) {
            const int64_t ngram_start = i - ngram_size;
            common_ngram ngram(&inp[ngram_start], ngram_size);
            const llama_token token = inp[i];

            common_ngram_cache::iterator part_it = ngram_cache.find(ngram);
            if (part_it == ngram_cache.end()) {
                common_ngram_cache_part part;
                part.emplace(token, 1);
                ngram_cache.emplace(ngram, part);
            } else {
                common_ngram_cache_part::iterator token_count_it = part_it->second.find(token);
                if (token_count_it == part_it->second.end()) {
                    part_it->second.emplace(token, 1);
                } else {
                    token_count_it->second++;
                }
            }
            ++n_done;

            if (print_progress && n_done % 10000000 == 0) {
                const int64_t t_now_ms = ggml_time_ms();
                // simple linear extrapolation of the remaining work
                const int64_t eta_ms  = (n_todo - n_done) * (t_now_ms - t_start_ms) / n_done;
                const int64_t eta_min = eta_ms / (60*1000);
                const int64_t eta_s   = (eta_ms - 60*1000*eta_min) / 1000;

                fprintf(stderr, "%s: %" PRId64 "/%" PRId64 " done, ETA: %02" PRId64 ":%02" PRId64 "\n", __func__, n_done, n_todo, eta_min, eta_s);
            }
        }
    }
}
|
|
||||||
|
|
||||||
// Helper function to get a token from the combined, speculative sequence of inp and draft.
|
|
||||||
static llama_token get_token(const std::vector<llama_token> & inp, const std::vector<llama_token> & draft, const size_t i) {
|
|
||||||
return i < inp.size() ? inp[i] : draft[1 + i - inp.size()];
|
|
||||||
}
|
|
||||||
|
|
||||||
// If sample size or percentage are below these thresholds the draft is aborted early:
|
|
||||||
constexpr int draft_min_sample_size_lax[LLAMA_NGRAM_MAX] = { 2, 2, 1, 1};
|
|
||||||
constexpr int draft_min_percent_lax[LLAMA_NGRAM_MAX] = {66, 50, 50, 50};
|
|
||||||
constexpr int draft_min_sample_size_strict[LLAMA_NGRAM_MAX] = { 4, 3, 2, 2};
|
|
||||||
constexpr int draft_min_percent_strict[LLAMA_NGRAM_MAX] = {75, 66, 66, 66};
|
|
||||||
|
|
||||||
// Helper function that tries to draft a token from only the static ngram cache:
|
|
||||||
static llama_token try_draft(common_ngram_cache & nc_static, const common_ngram ngram_static) {
|
|
||||||
common_ngram_cache::iterator part_static_it = nc_static.find(ngram_static);
|
|
||||||
if (part_static_it == nc_static.end()) {
|
|
||||||
return LLAMA_TOKEN_NULL;
|
|
||||||
}
|
|
||||||
const common_ngram_cache_part part_static = part_static_it->second;
|
|
||||||
|
|
||||||
int max_count_static = 0;
|
|
||||||
int sum_count_static = 0;
|
|
||||||
llama_token max_token = LLAMA_TOKEN_NULL;
|
|
||||||
|
|
||||||
for (std::pair<llama_token, int> token_count_static : part_static) {
|
|
||||||
const llama_token token = token_count_static.first;
|
|
||||||
const int32_t count_static = token_count_static.second;
|
|
||||||
|
|
||||||
if (count_static > max_count_static) {
|
|
||||||
max_token = token;
|
|
||||||
max_count_static = count_static;
|
|
||||||
}
|
|
||||||
sum_count_static += count_static;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (sum_count_static < draft_min_sample_size_lax[LLAMA_NGRAM_STATIC-1]) {
|
|
||||||
return LLAMA_TOKEN_NULL;
|
|
||||||
}
|
|
||||||
if (100*max_count_static < draft_min_percent_lax[LLAMA_NGRAM_STATIC-1]*sum_count_static) {
|
|
||||||
return LLAMA_TOKEN_NULL;
|
|
||||||
}
|
|
||||||
return max_token;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Try to draft a token from primary cache (context/dynamic), validate with static cache:
|
|
||||||
static llama_token try_draft(
|
|
||||||
common_ngram_cache & nc_primary, const std::vector<common_ngram> & ngrams_primary, common_ngram_cache_part & part_static,
|
|
||||||
const int * min_sample_size, const int * min_percent) {
|
|
||||||
|
|
||||||
llama_token drafted_token = LLAMA_TOKEN_NULL;
|
|
||||||
|
|
||||||
for (int i = ngrams_primary.size()-1; i >= 0 && drafted_token == LLAMA_TOKEN_NULL; --i) {
|
|
||||||
const common_ngram ngram_primary = ngrams_primary[i];
|
|
||||||
|
|
||||||
common_ngram_cache::iterator part_primary_it = nc_primary.find(ngram_primary);
|
|
||||||
if (part_primary_it == nc_primary.end()) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
const common_ngram_cache_part part_primary = part_primary_it->second;
|
|
||||||
|
|
||||||
int max_count_primary = 0;
|
|
||||||
int max_count_static = 0;
|
|
||||||
int sum_count_primary = 0;
|
|
||||||
llama_token max_token = LLAMA_TOKEN_NULL;
|
|
||||||
|
|
||||||
for (std::pair<llama_token, int> token_count_primary : part_primary) {
|
|
||||||
const llama_token token = token_count_primary.first;
|
|
||||||
|
|
||||||
common_ngram_cache_part::iterator token_count_static_it = part_static.find(token);
|
|
||||||
|
|
||||||
const int32_t count_primary = token_count_primary.second;
|
|
||||||
const int32_t count_static = token_count_static_it != part_static.end() ? 100*token_count_static_it->second : 1;
|
|
||||||
|
|
||||||
if (count_primary*count_static > max_count_primary*max_count_static) {
|
|
||||||
max_token = token;
|
|
||||||
max_count_primary = count_primary;
|
|
||||||
max_count_static = count_static;
|
|
||||||
}
|
|
||||||
sum_count_primary += count_primary;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (sum_count_primary < min_sample_size[i]) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (100*max_count_primary < min_percent[i]*sum_count_primary) {
|
|
||||||
continue;;
|
|
||||||
}
|
|
||||||
drafted_token = max_token;
|
|
||||||
}
|
|
||||||
|
|
||||||
return drafted_token;
|
|
||||||
}
|
|
||||||
|
|
||||||
void common_ngram_cache_draft(
|
|
||||||
std::vector<llama_token> & inp, std::vector<llama_token> & draft, int n_draft, int ngram_min, int ngram_max,
|
|
||||||
common_ngram_cache & nc_context, common_ngram_cache & nc_dynamic, common_ngram_cache & nc_static
|
|
||||||
) {
|
|
||||||
GGML_ASSERT(draft.size() == 1);
|
|
||||||
const int inp_size = inp.size();
|
|
||||||
|
|
||||||
if (inp_size < LLAMA_NGRAM_STATIC) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
while ((int) draft.size()-1 < n_draft) {
|
|
||||||
llama_token drafted_token = LLAMA_TOKEN_NULL;
|
|
||||||
|
|
||||||
const int ngram_start_static = inp_size-LLAMA_NGRAM_STATIC + draft.size()-1;
|
|
||||||
common_ngram ngram_static;
|
|
||||||
for (int j = ngram_start_static; j < ngram_start_static + LLAMA_NGRAM_STATIC; ++j) {
|
|
||||||
ngram_static.tokens[j-ngram_start_static] = get_token(inp, draft, j);
|
|
||||||
}
|
|
||||||
common_ngram_cache::iterator part_static_it = nc_static.find(ngram_static);
|
|
||||||
common_ngram_cache_part part_static;
|
|
||||||
if (part_static_it != nc_static.end()) {
|
|
||||||
part_static = part_static_it->second;
|
|
||||||
}
|
|
||||||
|
|
||||||
// cd = context + dynamic
|
|
||||||
std::vector<common_ngram> ngrams_cd;
|
|
||||||
for (int ngram_size_cd = ngram_min; ngram_size_cd <= ngram_max; ++ngram_size_cd) {
|
|
||||||
const int ngram_start_cd = inp_size-ngram_size_cd + draft.size()-1;
|
|
||||||
common_ngram ngram_cd;
|
|
||||||
for (int j = ngram_start_cd; j < ngram_start_cd + ngram_size_cd; ++j) {
|
|
||||||
ngram_cd.tokens[j-ngram_start_cd] = get_token(inp, draft, j);
|
|
||||||
}
|
|
||||||
ngrams_cd.push_back(ngram_cd);
|
|
||||||
}
|
|
||||||
if (drafted_token == LLAMA_TOKEN_NULL) {
|
|
||||||
drafted_token = try_draft(nc_context, ngrams_cd, part_static, draft_min_sample_size_lax, draft_min_percent_lax);
|
|
||||||
}
|
|
||||||
if (drafted_token == LLAMA_TOKEN_NULL) {
|
|
||||||
drafted_token = try_draft(nc_dynamic, ngrams_cd, part_static, draft_min_sample_size_strict, draft_min_percent_strict);
|
|
||||||
}
|
|
||||||
if (drafted_token == LLAMA_TOKEN_NULL) {
|
|
||||||
drafted_token = try_draft(nc_static, ngram_static);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (drafted_token == LLAMA_TOKEN_NULL) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
LOG(" - draft candidate: token=%d\n", drafted_token);
|
|
||||||
draft.push_back(drafted_token);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Serialize the cache to a binary file. Record layout per ngram:
//   common_ngram | int32 ntokens | ntokens x (llama_token, int32 count)
void common_ngram_cache_save(common_ngram_cache & ngram_cache, std::string & filename) {
    std::ofstream file_out(filename, std::ios::binary);
    for (const auto & item : ngram_cache) {
        const common_ngram & ngram = item.first;
        const common_ngram_cache_part & token_counts = item.second;
        GGML_ASSERT(!token_counts.empty());
        const int32_t ntokens = token_counts.size();
        GGML_ASSERT(ntokens > 0);

        file_out.write(reinterpret_cast<const char *>(&ngram),   sizeof(common_ngram));
        file_out.write(reinterpret_cast<const char *>(&ntokens), sizeof(int32_t));
        for (const auto & item2 : token_counts) {
            const llama_token token = item2.first;
            const int32_t     count = item2.second;
            GGML_ASSERT(count > 0);

            file_out.write(reinterpret_cast<const char *>(&token), sizeof(llama_token));
            file_out.write(reinterpret_cast<const char *>(&count), sizeof(int32_t));
        }
    }
}
|
|
||||||
|
|
||||||
// Load a cache written by common_ngram_cache_save. Throws std::ifstream::failure
// if the file cannot be opened; asserts on a truncated/corrupt record.
common_ngram_cache common_ngram_cache_load(std::string & filename) {
    std::ifstream hashmap_file(filename, std::ios::binary);
    if (!hashmap_file) {
        throw std::ifstream::failure("Unable to open file " + filename);
    }
    common_ngram_cache ngram_cache;

    // scratch variables reused for every record
    common_ngram ngram;
    int32_t      ntokens;
    llama_token  token;
    int32_t      count;

    char * ngramc   = reinterpret_cast<char*>(&ngram);
    char * ntokensc = reinterpret_cast<char*>(&ntokens);
    char * tokenc   = reinterpret_cast<char*>(&token);
    char * countc   = reinterpret_cast<char*>(&count);
    while (hashmap_file.read(ngramc, sizeof(common_ngram))) {
        GGML_ASSERT(!hashmap_file.eof());
        GGML_ASSERT(hashmap_file.read(ntokensc, sizeof(int32_t)));
        GGML_ASSERT(ntokens > 0);
        common_ngram_cache_part token_counts;

        for (int i = 0; i < ntokens; ++i) {
            GGML_ASSERT(!hashmap_file.eof());
            GGML_ASSERT(hashmap_file.read(tokenc, sizeof(llama_token)));
            GGML_ASSERT(!hashmap_file.eof());
            GGML_ASSERT(hashmap_file.read(countc, sizeof(int32_t)));
            GGML_ASSERT(count > 0);
            token_counts.emplace(token, count);
        }

        ngram_cache.emplace(ngram, token_counts);
    }
    GGML_ASSERT(hashmap_file.eof());

    return ngram_cache;
}
|
|
||||||
|
|
||||||
// Merge ngram_cache_add into ngram_cache_target by summing token counts.
void common_ngram_cache_merge(common_ngram_cache & ngram_cache_target, common_ngram_cache & ngram_cache_add) {
    for (std::pair<common_ngram, common_ngram_cache_part> ngram_part : ngram_cache_add) {
        const common_ngram ngram       = ngram_part.first;
        common_ngram_cache_part part   = ngram_part.second;

        common_ngram_cache::iterator part_merged_it = ngram_cache_target.find(ngram);
        if (part_merged_it == ngram_cache_target.end()) {
            // ngram not yet present in the target: copy the whole part over
            ngram_cache_target.emplace(ngram, part);
            continue;
        }

        for (std::pair<llama_token, int32_t> token_count : part) {
            const llama_token token = token_count.first;
            const int32_t     count = token_count.second;
            GGML_ASSERT(count > 0);

            common_ngram_cache_part::iterator token_count_merged_it = part_merged_it->second.find(token);
            if (token_count_merged_it == part_merged_it->second.end()) {
                part_merged_it->second.emplace(token, count);
                continue;
            }

            token_count_merged_it->second += count;
        }
    }
}
|
|
||||||
@@ -1,101 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
#include "llama.h"
|
|
||||||
|
|
||||||
#include <unordered_map>
|
|
||||||
#include <string>
|
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
#define LLAMA_NGRAM_MIN 1
|
|
||||||
#define LLAMA_NGRAM_MAX 4
|
|
||||||
#define LLAMA_NGRAM_STATIC 2
|
|
||||||
|
|
||||||
// Data structures to map n-grams to empirical token probabilities:
|
|
||||||
|
|
||||||
struct common_ngram {
|
|
||||||
llama_token tokens[LLAMA_NGRAM_MAX];
|
|
||||||
|
|
||||||
common_ngram() {
|
|
||||||
for (int i = 0; i < LLAMA_NGRAM_MAX; ++i) {
|
|
||||||
tokens[i] = LLAMA_TOKEN_NULL;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
common_ngram(const llama_token * input, const int ngram_size) {
|
|
||||||
for (int i = 0; i < LLAMA_NGRAM_MAX; ++i) {
|
|
||||||
tokens[i] = i < ngram_size ? input[i] : LLAMA_TOKEN_NULL;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
bool operator==(const common_ngram & other) const {
|
|
||||||
for (int i = 0; i < LLAMA_NGRAM_MAX; ++i) {
|
|
||||||
if (tokens[i] != other.tokens[i]) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
struct common_token_hash_function {
|
|
||||||
size_t operator()(const llama_token token) const {
|
|
||||||
// see https://probablydance.com/2018/06/16/fibonacci-hashing-the-optimization-that-the-world-forgot-or-a-better-alternative-to-integer-modulo/
|
|
||||||
return token * 11400714819323198485llu;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
struct common_ngram_hash_function {
|
|
||||||
size_t operator()(const common_ngram & ngram) const {
|
|
||||||
size_t hash = common_token_hash_function{}(ngram.tokens[0]);
|
|
||||||
for (int i = 1; i < LLAMA_NGRAM_MAX; ++i) {
|
|
||||||
hash ^= common_token_hash_function{}(ngram.tokens[i]);
|
|
||||||
}
|
|
||||||
return hash;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// token -> number of times token has been seen
|
|
||||||
typedef std::unordered_map<llama_token, int32_t> common_ngram_cache_part;
|
|
||||||
|
|
||||||
// n-gram -> empirical distribution of following tokens
|
|
||||||
typedef std::unordered_map<common_ngram, common_ngram_cache_part, common_ngram_hash_function> common_ngram_cache;
|
|
||||||
|
|
||||||
|
|
||||||
// Update an ngram cache with tokens.
|
|
||||||
// ngram_cache: the cache to modify.
|
|
||||||
// ngram_min/ngram_max: the min/max size of the ngrams to extract from inp_data.
|
|
||||||
// inp_data: the token sequence with which to update ngram_cache.
|
|
||||||
// nnew: how many new tokens have been appended to inp_data since the last call to this function.
|
|
||||||
// print_progress: whether to print progress to stderr.
|
|
||||||
//
|
|
||||||
// In order to get correct results inp_data can ONLY BE APPENDED TO.
|
|
||||||
// Changes in the middle need a complete rebuild.
|
|
||||||
void common_ngram_cache_update(
|
|
||||||
common_ngram_cache & ngram_cache, int ngram_min, int ngram_max, std::vector<llama_token> & inp_data, int nnew, bool print_progress);
|
|
||||||
|
|
||||||
// Try to draft tokens from ngram caches.
|
|
||||||
// inp: the tokens generated so far.
|
|
||||||
// draft: the token sequence to draft. Expected to initially contain the previously sampled token.
|
|
||||||
// n_draft: maximum number of tokens to add to draft.
|
|
||||||
// ngram_min/gram_max: the min/max size of the ngrams in nc_context and nc_dynamic.
|
|
||||||
// nc_context: ngram cache based on current context.
|
|
||||||
// nc_dynamic: ngram cache based on previous user generations.
|
|
||||||
// nc_static: ngram cache generated from a large text corpus, used for validation.
|
|
||||||
void common_ngram_cache_draft(
|
|
||||||
std::vector<llama_token> & inp, std::vector<llama_token> & draft, int n_draft, int ngram_min, int ngram_max,
|
|
||||||
common_ngram_cache & nc_context, common_ngram_cache & nc_dynamic, common_ngram_cache & nc_static);
|
|
||||||
|
|
||||||
// Save an ngram cache to a file.
|
|
||||||
// ngram_cache: the ngram cache to save.
|
|
||||||
// filename: the path under which to save the ngram cache.
|
|
||||||
void common_ngram_cache_save(common_ngram_cache & ngram_cache, std::string & filename);
|
|
||||||
|
|
||||||
// Load an ngram cache saved with common_ngram_cache_save.
|
|
||||||
// filename: the path from which to load the ngram cache.
|
|
||||||
// returns: an ngram cache containing the information saved to filename.
|
|
||||||
common_ngram_cache common_ngram_cache_load(std::string & filename);
|
|
||||||
|
|
||||||
// Merge two ngram caches.
|
|
||||||
// ngram_cache_target: the ngram cache to which to add the information from ngram_cache_add.
|
|
||||||
// ngram_cache_add: the ngram cache to add to ngram_cache_target.
|
|
||||||
void common_ngram_cache_merge(common_ngram_cache & ngram_cache_target, common_ngram_cache & ngram_cache_add);
|
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -1,459 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
#include <nlohmann/json_fwd.hpp>
|
|
||||||
|
|
||||||
#include <memory>
|
|
||||||
#include <unordered_map>
|
|
||||||
#include <string>
|
|
||||||
#include <string_view>
|
|
||||||
#include <functional>
|
|
||||||
#include <vector>
|
|
||||||
#include <variant>
|
|
||||||
|
|
||||||
struct common_grammar_builder;
|
|
||||||
|
|
||||||
class common_peg_parser_builder;
|
|
||||||
|
|
||||||
using common_peg_parser_id = size_t;
|
|
||||||
constexpr common_peg_parser_id COMMON_PEG_INVALID_PARSER_ID = static_cast<common_peg_parser_id>(-1);
|
|
||||||
|
|
||||||
using common_peg_ast_id = size_t;
|
|
||||||
constexpr common_peg_ast_id COMMON_PEG_INVALID_AST_ID = static_cast<common_peg_ast_id>(-1);
|
|
||||||
|
|
||||||
// Lightweight wrapper around common_peg_parser_id for convenience
|
|
||||||
class common_peg_parser {
|
|
||||||
common_peg_parser_id id_;
|
|
||||||
common_peg_parser_builder & builder_;
|
|
||||||
|
|
||||||
public:
|
|
||||||
common_peg_parser(const common_peg_parser & other) : id_(other.id_), builder_(other.builder_) {}
|
|
||||||
common_peg_parser(common_peg_parser_id id, common_peg_parser_builder & builder) : id_(id), builder_(builder) {}
|
|
||||||
|
|
||||||
common_peg_parser & operator=(const common_peg_parser & other);
|
|
||||||
common_peg_parser & operator+=(const common_peg_parser & other);
|
|
||||||
common_peg_parser & operator|=(const common_peg_parser & other);
|
|
||||||
|
|
||||||
operator common_peg_parser_id() const { return id_; }
|
|
||||||
common_peg_parser_id id() const { return id_; }
|
|
||||||
|
|
||||||
common_peg_parser_builder & builder() const { return builder_; }
|
|
||||||
|
|
||||||
// Creates a sequence
|
|
||||||
common_peg_parser operator+(const common_peg_parser & other) const;
|
|
||||||
|
|
||||||
// Creates a sequence separated by spaces.
|
|
||||||
common_peg_parser operator<<(const common_peg_parser & other) const;
|
|
||||||
|
|
||||||
// Creates a choice
|
|
||||||
common_peg_parser operator|(const common_peg_parser & other) const;
|
|
||||||
|
|
||||||
common_peg_parser operator+(const char * str) const;
|
|
||||||
common_peg_parser operator+(const std::string & str) const;
|
|
||||||
common_peg_parser operator<<(const char * str) const;
|
|
||||||
common_peg_parser operator<<(const std::string & str) const;
|
|
||||||
common_peg_parser operator|(const char * str) const;
|
|
||||||
common_peg_parser operator|(const std::string & str) const;
|
|
||||||
};
|
|
||||||
|
|
||||||
common_peg_parser operator+(const char * str, const common_peg_parser & p);
|
|
||||||
common_peg_parser operator+(const std::string & str, const common_peg_parser & p);
|
|
||||||
common_peg_parser operator<<(const char * str, const common_peg_parser & p);
|
|
||||||
common_peg_parser operator<<(const std::string & str, const common_peg_parser & p);
|
|
||||||
common_peg_parser operator|(const char * str, const common_peg_parser & p);
|
|
||||||
common_peg_parser operator|(const std::string & str, const common_peg_parser & p);
|
|
||||||
|
|
||||||
// Outcome of a PEG parse attempt.
enum common_peg_parse_result_type {
    COMMON_PEG_PARSE_RESULT_FAIL            = 0,
    COMMON_PEG_PARSE_RESULT_SUCCESS         = 1,
    COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT = 2, // input ended before the parse could be decided
};
|
|
||||||
|
|
||||||
const char * common_peg_parse_result_type_name(common_peg_parse_result_type type);
|
|
||||||
|
|
||||||
struct common_peg_ast_node {
|
|
||||||
common_peg_ast_id id;
|
|
||||||
std::string rule;
|
|
||||||
std::string tag;
|
|
||||||
size_t start;
|
|
||||||
size_t end;
|
|
||||||
std::string_view text;
|
|
||||||
std::vector<common_peg_ast_id> children;
|
|
||||||
|
|
||||||
bool is_partial = false;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct common_peg_parse_result;
|
|
||||||
|
|
||||||
using common_peg_ast_visitor = std::function<void(const common_peg_ast_node & node)>;
|
|
||||||
|
|
||||||
class common_peg_ast_arena {
|
|
||||||
std::vector<common_peg_ast_node> nodes_;
|
|
||||||
public:
|
|
||||||
common_peg_ast_id add_node(
|
|
||||||
const std::string & rule,
|
|
||||||
const std::string & tag,
|
|
||||||
size_t start,
|
|
||||||
size_t end,
|
|
||||||
std::string_view text,
|
|
||||||
std::vector<common_peg_ast_id> children,
|
|
||||||
bool is_partial = false
|
|
||||||
) {
|
|
||||||
common_peg_ast_id id = nodes_.size();
|
|
||||||
nodes_.push_back({id, rule, tag, start, end, text, std::move(children), is_partial});
|
|
||||||
return id;
|
|
||||||
}
|
|
||||||
|
|
||||||
const common_peg_ast_node & get(common_peg_ast_id id) const { return nodes_.at(id); }
|
|
||||||
|
|
||||||
size_t size() const { return nodes_.size(); }
|
|
||||||
|
|
||||||
void clear() { nodes_.clear(); }
|
|
||||||
|
|
||||||
void visit(common_peg_ast_id id, const common_peg_ast_visitor & visitor) const;
|
|
||||||
void visit(const common_peg_parse_result & result, const common_peg_ast_visitor & visitor) const;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct common_peg_parse_result {
|
|
||||||
common_peg_parse_result_type type = COMMON_PEG_PARSE_RESULT_FAIL;
|
|
||||||
size_t start = 0;
|
|
||||||
size_t end = 0;
|
|
||||||
|
|
||||||
std::vector<common_peg_ast_id> nodes;
|
|
||||||
|
|
||||||
common_peg_parse_result() = default;
|
|
||||||
|
|
||||||
common_peg_parse_result(common_peg_parse_result_type type, size_t start)
|
|
||||||
: type(type), start(start), end(start) {}
|
|
||||||
|
|
||||||
common_peg_parse_result(common_peg_parse_result_type type, size_t start, size_t end)
|
|
||||||
: type(type), start(start), end(end) {}
|
|
||||||
|
|
||||||
common_peg_parse_result(common_peg_parse_result_type type, size_t start, size_t end, std::vector<common_peg_ast_id> nodes)
|
|
||||||
: type(type), start(start), end(end), nodes(std::move(nodes)) {}
|
|
||||||
|
|
||||||
bool fail() const { return type == COMMON_PEG_PARSE_RESULT_FAIL; }
|
|
||||||
bool need_more_input() const { return type == COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT; }
|
|
||||||
bool success() const { return type == COMMON_PEG_PARSE_RESULT_SUCCESS; }
|
|
||||||
};
|
|
||||||
|
|
||||||
struct common_peg_parse_context {
|
|
||||||
std::string input;
|
|
||||||
bool is_partial;
|
|
||||||
common_peg_ast_arena ast;
|
|
||||||
|
|
||||||
int parse_depth;
|
|
||||||
|
|
||||||
common_peg_parse_context()
|
|
||||||
: is_partial(false), parse_depth(0) {}
|
|
||||||
|
|
||||||
common_peg_parse_context(const std::string & input)
|
|
||||||
: input(input), is_partial(false), parse_depth(0) {}
|
|
||||||
|
|
||||||
common_peg_parse_context(const std::string & input, bool is_partial)
|
|
||||||
: input(input), is_partial(is_partial), parse_depth(0) {}
|
|
||||||
};
|
|
||||||
|
|
||||||
class common_peg_arena;
|
|
||||||
|
|
||||||
// Parser variants
|
|
||||||
struct common_peg_epsilon_parser {};
|
|
||||||
|
|
||||||
struct common_peg_start_parser {};
|
|
||||||
|
|
||||||
struct common_peg_end_parser {};
|
|
||||||
|
|
||||||
struct common_peg_literal_parser {
|
|
||||||
std::string literal;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct common_peg_sequence_parser {
|
|
||||||
std::vector<common_peg_parser_id> children;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct common_peg_choice_parser {
|
|
||||||
std::vector<common_peg_parser_id> children;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct common_peg_repetition_parser {
|
|
||||||
common_peg_parser_id child;
|
|
||||||
int min_count;
|
|
||||||
int max_count; // -1 for unbounded
|
|
||||||
};
|
|
||||||
|
|
||||||
struct common_peg_and_parser {
|
|
||||||
common_peg_parser_id child;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct common_peg_not_parser {
|
|
||||||
common_peg_parser_id child;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct common_peg_any_parser {};
|
|
||||||
|
|
||||||
struct common_peg_space_parser {};
|
|
||||||
|
|
||||||
struct common_peg_chars_parser {
|
|
||||||
struct char_range {
|
|
||||||
uint32_t start;
|
|
||||||
uint32_t end;
|
|
||||||
bool contains(uint32_t codepoint) const { return codepoint >= start && codepoint <= end; }
|
|
||||||
};
|
|
||||||
|
|
||||||
std::string pattern;
|
|
||||||
std::vector<char_range> ranges;
|
|
||||||
bool negated;
|
|
||||||
int min_count;
|
|
||||||
int max_count; // -1 for unbounded
|
|
||||||
};
|
|
||||||
|
|
||||||
struct common_peg_json_string_parser {};
|
|
||||||
|
|
||||||
struct common_peg_until_parser {
|
|
||||||
std::vector<std::string> delimiters;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct common_peg_schema_parser {
|
|
||||||
common_peg_parser_id child;
|
|
||||||
std::string name;
|
|
||||||
std::shared_ptr<nlohmann::ordered_json> schema;
|
|
||||||
|
|
||||||
// Indicates if the GBNF should accept a raw string that matches the schema.
|
|
||||||
bool raw;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct common_peg_rule_parser {
|
|
||||||
std::string name;
|
|
||||||
common_peg_parser_id child;
|
|
||||||
bool trigger;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct common_peg_ref_parser {
|
|
||||||
std::string name;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct common_peg_atomic_parser {
|
|
||||||
common_peg_parser_id child;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct common_peg_tag_parser {
|
|
||||||
common_peg_parser_id child;
|
|
||||||
std::string tag;
|
|
||||||
};
|
|
||||||
|
|
||||||
// Variant holding all parser types
|
|
||||||
using common_peg_parser_variant = std::variant<
|
|
||||||
common_peg_epsilon_parser,
|
|
||||||
common_peg_start_parser,
|
|
||||||
common_peg_end_parser,
|
|
||||||
common_peg_literal_parser,
|
|
||||||
common_peg_sequence_parser,
|
|
||||||
common_peg_choice_parser,
|
|
||||||
common_peg_repetition_parser,
|
|
||||||
common_peg_and_parser,
|
|
||||||
common_peg_not_parser,
|
|
||||||
common_peg_any_parser,
|
|
||||||
common_peg_space_parser,
|
|
||||||
common_peg_chars_parser,
|
|
||||||
common_peg_json_string_parser,
|
|
||||||
common_peg_until_parser,
|
|
||||||
common_peg_schema_parser,
|
|
||||||
common_peg_rule_parser,
|
|
||||||
common_peg_ref_parser,
|
|
||||||
common_peg_atomic_parser,
|
|
||||||
common_peg_tag_parser
|
|
||||||
>;
|
|
||||||
|
|
||||||
class common_peg_arena {
|
|
||||||
std::vector<common_peg_parser_variant> parsers_;
|
|
||||||
std::unordered_map<std::string, common_peg_parser_id> rules_;
|
|
||||||
common_peg_parser_id root_ = COMMON_PEG_INVALID_PARSER_ID;
|
|
||||||
|
|
||||||
public:
|
|
||||||
const common_peg_parser_variant & get(common_peg_parser_id id) const { return parsers_.at(id); }
|
|
||||||
common_peg_parser_variant & get(common_peg_parser_id id) { return parsers_.at(id); }
|
|
||||||
|
|
||||||
size_t size() const { return parsers_.size(); }
|
|
||||||
bool empty() const { return parsers_.empty(); }
|
|
||||||
|
|
||||||
common_peg_parser_id get_rule(const std::string & name) const;
|
|
||||||
bool has_rule(const std::string & name) const { return rules_.find(name) != rules_.end(); }
|
|
||||||
|
|
||||||
common_peg_parser_id root() const { return root_; }
|
|
||||||
void set_root(common_peg_parser_id id) { root_ = id; }
|
|
||||||
|
|
||||||
common_peg_parse_result parse(common_peg_parse_context & ctx, size_t start = 0) const;
|
|
||||||
common_peg_parse_result parse(common_peg_parser_id id, common_peg_parse_context & ctx, size_t start) const;
|
|
||||||
|
|
||||||
void resolve_refs();
|
|
||||||
|
|
||||||
void build_grammar(const common_grammar_builder & builder, bool lazy = false) const;
|
|
||||||
|
|
||||||
std::string dump(common_peg_parser_id id) const;
|
|
||||||
|
|
||||||
nlohmann::json to_json() const;
|
|
||||||
static common_peg_arena from_json(const nlohmann::json & j);
|
|
||||||
|
|
||||||
std::string save() const;
|
|
||||||
void load(const std::string & data);
|
|
||||||
|
|
||||||
friend class common_peg_parser_builder;
|
|
||||||
|
|
||||||
private:
|
|
||||||
common_peg_parser_id add_parser(common_peg_parser_variant parser);
|
|
||||||
void add_rule(const std::string & name, common_peg_parser_id id);
|
|
||||||
|
|
||||||
common_peg_parser_id resolve_ref(common_peg_parser_id id);
|
|
||||||
};
|
|
||||||
|
|
||||||
class common_peg_parser_builder {
|
|
||||||
common_peg_arena arena_;
|
|
||||||
|
|
||||||
common_peg_parser wrap(common_peg_parser_id id) { return common_peg_parser(id, *this); }
|
|
||||||
common_peg_parser add(const common_peg_parser_variant & p) { return wrap(arena_.add_parser(p)); }
|
|
||||||
|
|
||||||
public:
|
|
||||||
common_peg_parser_builder();
|
|
||||||
|
|
||||||
// Match nothing, always succeed.
|
|
||||||
// S -> ε
|
|
||||||
common_peg_parser eps() { return add(common_peg_epsilon_parser{}); }
|
|
||||||
|
|
||||||
// Matches the start of the input.
|
|
||||||
// S -> ^
|
|
||||||
common_peg_parser start() { return add(common_peg_start_parser{}); }
|
|
||||||
|
|
||||||
// Matches the end of the input.
|
|
||||||
// S -> $
|
|
||||||
common_peg_parser end() { return add(common_peg_end_parser{}); }
|
|
||||||
|
|
||||||
// Matches an exact literal string.
|
|
||||||
// S -> "hello"
|
|
||||||
common_peg_parser literal(const std::string & literal) { return add(common_peg_literal_parser{literal}); }
|
|
||||||
|
|
||||||
// Matches a sequence of parsers in order, all must succeed.
|
|
||||||
// S -> A B C
|
|
||||||
common_peg_parser sequence() { return add(common_peg_sequence_parser{}); }
|
|
||||||
common_peg_parser sequence(const std::vector<common_peg_parser_id> & parsers);
|
|
||||||
common_peg_parser sequence(const std::vector<common_peg_parser> & parsers);
|
|
||||||
common_peg_parser sequence(std::initializer_list<common_peg_parser> parsers);
|
|
||||||
|
|
||||||
// Matches the first parser that succeeds from a list of alternatives.
|
|
||||||
// S -> A | B | C
|
|
||||||
common_peg_parser choice() { return add(common_peg_choice_parser{}); }
|
|
||||||
common_peg_parser choice(const std::vector<common_peg_parser_id> & parsers);
|
|
||||||
common_peg_parser choice(const std::vector<common_peg_parser> & parsers);
|
|
||||||
common_peg_parser choice(std::initializer_list<common_peg_parser> parsers);
|
|
||||||
|
|
||||||
// Matches one or more repetitions of a parser.
|
|
||||||
// S -> A+
|
|
||||||
common_peg_parser one_or_more(const common_peg_parser & p) { return repeat(p, 1, -1); }
|
|
||||||
|
|
||||||
// Matches zero or more repetitions of a parser, always succeeds.
|
|
||||||
// S -> A*
|
|
||||||
common_peg_parser zero_or_more(const common_peg_parser & p) { return repeat(p, 0, -1); }
|
|
||||||
|
|
||||||
// Matches zero or one occurrence of a parser, always succeeds.
|
|
||||||
// S -> A?
|
|
||||||
common_peg_parser optional(const common_peg_parser & p) { return repeat(p, 0, 1); }
|
|
||||||
|
|
||||||
// Positive lookahead: succeeds if child parser succeeds, consumes no input.
|
|
||||||
// S -> &A
|
|
||||||
common_peg_parser peek(const common_peg_parser & p) { return add(common_peg_and_parser{p}); }
|
|
||||||
|
|
||||||
// Negative lookahead: succeeds if child parser fails, consumes no input.
|
|
||||||
// S -> !A
|
|
||||||
common_peg_parser negate(const common_peg_parser & p) { return add(common_peg_not_parser{p}); }
|
|
||||||
|
|
||||||
// Matches any single character.
|
|
||||||
// S -> .
|
|
||||||
common_peg_parser any() { return add(common_peg_any_parser{}); }
|
|
||||||
|
|
||||||
// Matches between min and max repetitions of characters from a character class.
|
|
||||||
// S -> [a-z]{m,n}
|
|
||||||
//
|
|
||||||
// Use -1 for max to represent unbounded repetition (equivalent to {m,})
|
|
||||||
common_peg_parser chars(const std::string & classes, int min = 1, int max = -1);
|
|
||||||
|
|
||||||
// Creates a lightweight reference to a named rule (resolved during build()).
|
|
||||||
// Use this for forward references in recursive grammars.
|
|
||||||
// expr_ref -> expr
|
|
||||||
common_peg_parser ref(const std::string & name) { return add(common_peg_ref_parser{name}); }
|
|
||||||
|
|
||||||
// Matches zero or more whitespace characters (space, tab, newline).
|
|
||||||
// S -> [ \t\n]*
|
|
||||||
common_peg_parser space() { return add(common_peg_space_parser{}); }
|
|
||||||
|
|
||||||
// Matches all characters until a delimiter is found (delimiter not consumed).
|
|
||||||
// S -> (!delim .)*
|
|
||||||
common_peg_parser until(const std::string & delimiter) { return add(common_peg_until_parser{{delimiter}}); }
|
|
||||||
|
|
||||||
// Matches all characters until one of the delimiters in the list is found (delimiter not consumed).
|
|
||||||
// S -> (!delim .)*
|
|
||||||
common_peg_parser until_one_of(const std::vector<std::string> & delimiters) { return add(common_peg_until_parser{delimiters}); }
|
|
||||||
|
|
||||||
// Matches everything
|
|
||||||
// S -> .*
|
|
||||||
common_peg_parser rest() { return until_one_of({}); }
|
|
||||||
|
|
||||||
// Matches between min and max repetitions of a parser (inclusive).
|
|
||||||
// S -> A{m,n}
|
|
||||||
// Use -1 for max to represent unbounded repetition (equivalent to {m,})
|
|
||||||
common_peg_parser repeat(const common_peg_parser & p, int min, int max) { return add(common_peg_repetition_parser{p, min,max}); }
|
|
||||||
|
|
||||||
// Matches exactly n repetitions of a parser.
|
|
||||||
// S -> A{n}
|
|
||||||
common_peg_parser repeat(const common_peg_parser & p, int n) { return repeat(p, n, n); }
|
|
||||||
|
|
||||||
// Creates a complete JSON parser supporting objects, arrays, strings, numbers, booleans, and null.
|
|
||||||
// value -> object | array | string | number | true | false | null
|
|
||||||
common_peg_parser json();
|
|
||||||
common_peg_parser json_object();
|
|
||||||
common_peg_parser json_string();
|
|
||||||
common_peg_parser json_array();
|
|
||||||
common_peg_parser json_number();
|
|
||||||
common_peg_parser json_bool();
|
|
||||||
common_peg_parser json_null();
|
|
||||||
|
|
||||||
// Matches JSON string content without the surrounding quotes.
|
|
||||||
// Useful for extracting content within a JSON string.
|
|
||||||
common_peg_parser json_string_content();
|
|
||||||
|
|
||||||
// Matches a JSON object member with a key and associated parser as the
|
|
||||||
// value.
|
|
||||||
common_peg_parser json_member(const std::string & key, const common_peg_parser & p);
|
|
||||||
|
|
||||||
// Wraps a parser with JSON schema metadata for grammar generation.
|
|
||||||
// Used internally to convert JSON schemas to GBNF grammar rules.
|
|
||||||
common_peg_parser schema(const common_peg_parser & p, const std::string & name, const nlohmann::ordered_json & schema, bool raw = false);
|
|
||||||
|
|
||||||
// Creates a named rule, stores it in the grammar, and returns a ref.
|
|
||||||
// If trigger=true, marks this rule as an entry point for lazy grammar generation.
|
|
||||||
// auto json = p.rule("json", json_obj | json_arr | ...)
|
|
||||||
common_peg_parser rule(const std::string & name, const common_peg_parser & p, bool trigger = false);
|
|
||||||
|
|
||||||
// Creates a named rule using a builder function, and returns a ref.
|
|
||||||
// If trigger=true, marks this rule as an entry point for lazy grammar generation.
|
|
||||||
// auto json = p.rule("json", [&]() { return json_object() | json_array() | ... })
|
|
||||||
common_peg_parser rule(const std::string & name, const std::function<common_peg_parser()> & builder, bool trigger = false);
|
|
||||||
|
|
||||||
// Creates a trigger rule. When generating a lazy grammar from the parser,
|
|
||||||
// only trigger rules and descendents are emitted.
|
|
||||||
common_peg_parser trigger_rule(const std::string & name, const common_peg_parser & p) { return rule(name, p, true); }
|
|
||||||
common_peg_parser trigger_rule(const std::string & name, const std::function<common_peg_parser()> & builder) { return rule(name, builder, true); }
|
|
||||||
|
|
||||||
// Creates an atomic parser. Atomic parsers do not create an AST node if
|
|
||||||
// the child results in a partial parse, i.e. NEEDS_MORE_INPUT. This is
|
|
||||||
// intended for situations where partial output is undesirable.
|
|
||||||
common_peg_parser atomic(const common_peg_parser & p) { return add(common_peg_atomic_parser{p}); }
|
|
||||||
|
|
||||||
// Tags create nodes in the generated AST for semantic purposes.
|
|
||||||
// Unlike rules, you can tag multiple nodes with the same tag.
|
|
||||||
common_peg_parser tag(const std::string & tag, const common_peg_parser & p) { return add(common_peg_tag_parser{p.id(), tag}); }
|
|
||||||
|
|
||||||
void set_root(const common_peg_parser & p);
|
|
||||||
|
|
||||||
common_peg_arena build();
|
|
||||||
};
|
|
||||||
|
|
||||||
// Helper function for building parsers
|
|
||||||
common_peg_arena build_peg_parser(const std::function<common_peg_parser(common_peg_parser_builder & builder)> & fn);
|
|
||||||
@@ -1,483 +0,0 @@
|
|||||||
#include "arg.h"
|
|
||||||
#include "preset.h"
|
|
||||||
#include "peg-parser.h"
|
|
||||||
#include "log.h"
|
|
||||||
#include "download.h"
|
|
||||||
|
|
||||||
#include <fstream>
|
|
||||||
#include <sstream>
|
|
||||||
#include <filesystem>
|
|
||||||
|
|
||||||
static std::string rm_leading_dashes(const std::string & str) {
|
|
||||||
size_t pos = 0;
|
|
||||||
while (pos < str.size() && str[pos] == '-') {
|
|
||||||
++pos;
|
|
||||||
}
|
|
||||||
return str.substr(pos);
|
|
||||||
}
|
|
||||||
|
|
||||||
// only allow a subset of args for remote presets for security reasons
|
|
||||||
// do not add more args unless absolutely necessary
|
|
||||||
// args that output to files are strictly prohibited
|
|
||||||
static std::set<std::string> get_remote_preset_whitelist(const std::map<std::string, common_arg> & key_to_opt) {
|
|
||||||
static const std::set<std::string> allowed_options = {
|
|
||||||
"model-url",
|
|
||||||
"hf-repo",
|
|
||||||
"hf-repo-draft",
|
|
||||||
"hf-repo-v", // vocoder
|
|
||||||
"hf-file-v", // vocoder
|
|
||||||
"mmproj-url",
|
|
||||||
"pooling",
|
|
||||||
"jinja",
|
|
||||||
"batch-size",
|
|
||||||
"ubatch-size",
|
|
||||||
"cache-reuse",
|
|
||||||
"chat-template-kwargs",
|
|
||||||
"mmap",
|
|
||||||
// note: sampling params are automatically allowed by default
|
|
||||||
// negated args will be added automatically if the positive arg is specified above
|
|
||||||
};
|
|
||||||
|
|
||||||
std::set<std::string> allowed_keys;
|
|
||||||
|
|
||||||
for (const auto & it : key_to_opt) {
|
|
||||||
const std::string & key = it.first;
|
|
||||||
const common_arg & opt = it.second;
|
|
||||||
if (allowed_options.find(key) != allowed_options.end() || opt.is_sparam) {
|
|
||||||
allowed_keys.insert(key);
|
|
||||||
// also add variant keys (args without leading dashes and env vars)
|
|
||||||
for (const auto & arg : opt.get_args()) {
|
|
||||||
allowed_keys.insert(rm_leading_dashes(arg));
|
|
||||||
}
|
|
||||||
for (const auto & env : opt.get_env()) {
|
|
||||||
allowed_keys.insert(env);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return allowed_keys;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<std::string> common_preset::to_args(const std::string & bin_path) const {
|
|
||||||
std::vector<std::string> args;
|
|
||||||
|
|
||||||
if (!bin_path.empty()) {
|
|
||||||
args.push_back(bin_path);
|
|
||||||
}
|
|
||||||
|
|
||||||
for (const auto & [opt, value] : options) {
|
|
||||||
if (opt.is_preset_only) {
|
|
||||||
continue; // skip preset-only options (they are not CLI args)
|
|
||||||
}
|
|
||||||
|
|
||||||
// use the last arg as the main arg (i.e. --long-form)
|
|
||||||
args.push_back(opt.args.back());
|
|
||||||
|
|
||||||
// handle value(s)
|
|
||||||
if (opt.value_hint == nullptr && opt.value_hint_2 == nullptr) {
|
|
||||||
// flag option, no value
|
|
||||||
if (common_arg_utils::is_falsey(value)) {
|
|
||||||
// use negative arg if available
|
|
||||||
if (!opt.args_neg.empty()) {
|
|
||||||
args.back() = opt.args_neg.back();
|
|
||||||
} else {
|
|
||||||
// otherwise, skip the flag
|
|
||||||
// TODO: maybe throw an error instead?
|
|
||||||
args.pop_back();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (opt.value_hint != nullptr) {
|
|
||||||
// single value
|
|
||||||
args.push_back(value);
|
|
||||||
}
|
|
||||||
if (opt.value_hint != nullptr && opt.value_hint_2 != nullptr) {
|
|
||||||
throw std::runtime_error(string_format(
|
|
||||||
"common_preset::to_args(): option '%s' has two values, which is not supported yet",
|
|
||||||
opt.args.back()
|
|
||||||
));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return args;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string common_preset::to_ini() const {
|
|
||||||
std::ostringstream ss;
|
|
||||||
|
|
||||||
ss << "[" << name << "]\n";
|
|
||||||
for (const auto & [opt, value] : options) {
|
|
||||||
auto espaced_value = value;
|
|
||||||
string_replace_all(espaced_value, "\n", "\\\n");
|
|
||||||
ss << rm_leading_dashes(opt.args.back()) << " = ";
|
|
||||||
ss << espaced_value << "\n";
|
|
||||||
}
|
|
||||||
ss << "\n";
|
|
||||||
|
|
||||||
return ss.str();
|
|
||||||
}
|
|
||||||
|
|
||||||
void common_preset::set_option(const common_preset_context & ctx, const std::string & env, const std::string & value) {
|
|
||||||
// try if option exists, update it
|
|
||||||
for (auto & [opt, val] : options) {
|
|
||||||
if (opt.env && env == opt.env) {
|
|
||||||
val = value;
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// if option does not exist, we need to add it
|
|
||||||
if (ctx.key_to_opt.find(env) == ctx.key_to_opt.end()) {
|
|
||||||
throw std::runtime_error(string_format(
|
|
||||||
"%s: option with env '%s' not found in ctx_params",
|
|
||||||
__func__, env.c_str()
|
|
||||||
));
|
|
||||||
}
|
|
||||||
options[ctx.key_to_opt.at(env)] = value;
|
|
||||||
}
|
|
||||||
|
|
||||||
void common_preset::unset_option(const std::string & env) {
|
|
||||||
for (auto it = options.begin(); it != options.end(); ) {
|
|
||||||
const common_arg & opt = it->first;
|
|
||||||
if (opt.env && env == opt.env) {
|
|
||||||
it = options.erase(it);
|
|
||||||
return;
|
|
||||||
} else {
|
|
||||||
++it;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
bool common_preset::get_option(const std::string & env, std::string & value) const {
|
|
||||||
for (const auto & [opt, val] : options) {
|
|
||||||
if (opt.env && env == opt.env) {
|
|
||||||
value = val;
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
void common_preset::merge(const common_preset & other) {
|
|
||||||
for (const auto & [opt, val] : other.options) {
|
|
||||||
options[opt] = val; // overwrite existing options
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void common_preset::apply_to_params(common_params & params) const {
|
|
||||||
for (const auto & [opt, val] : options) {
|
|
||||||
// apply each option to params
|
|
||||||
if (opt.handler_string) {
|
|
||||||
opt.handler_string(params, val);
|
|
||||||
} else if (opt.handler_int) {
|
|
||||||
opt.handler_int(params, std::stoi(val));
|
|
||||||
} else if (opt.handler_bool) {
|
|
||||||
opt.handler_bool(params, common_arg_utils::is_truthy(val));
|
|
||||||
} else if (opt.handler_str_str) {
|
|
||||||
// not supported yet
|
|
||||||
throw std::runtime_error(string_format(
|
|
||||||
"%s: option with two values is not supported yet",
|
|
||||||
__func__
|
|
||||||
));
|
|
||||||
} else if (opt.handler_void) {
|
|
||||||
opt.handler_void(params);
|
|
||||||
} else {
|
|
||||||
GGML_ABORT("unknown handler type");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
static std::map<std::string, std::map<std::string, std::string>> parse_ini_from_file(const std::string & path) {
|
|
||||||
std::map<std::string, std::map<std::string, std::string>> parsed;
|
|
||||||
|
|
||||||
if (!std::filesystem::exists(path)) {
|
|
||||||
throw std::runtime_error("preset file does not exist: " + path);
|
|
||||||
}
|
|
||||||
|
|
||||||
std::ifstream file(path);
|
|
||||||
if (!file.good()) {
|
|
||||||
throw std::runtime_error("failed to open server preset file: " + path);
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string contents((std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());
|
|
||||||
|
|
||||||
static const auto parser = build_peg_parser([](auto & p) {
|
|
||||||
// newline ::= "\r\n" / "\n" / "\r"
|
|
||||||
auto newline = p.rule("newline", p.literal("\r\n") | p.literal("\n") | p.literal("\r"));
|
|
||||||
|
|
||||||
// ws ::= [ \t]*
|
|
||||||
auto ws = p.rule("ws", p.chars("[ \t]", 0, -1));
|
|
||||||
|
|
||||||
// comment ::= [;#] (!newline .)*
|
|
||||||
auto comment = p.rule("comment", p.chars("[;#]", 1, 1) + p.zero_or_more(p.negate(newline) + p.any()));
|
|
||||||
|
|
||||||
// eol ::= ws comment? (newline / EOF)
|
|
||||||
auto eol = p.rule("eol", ws + p.optional(comment) + (newline | p.end()));
|
|
||||||
|
|
||||||
// ident ::= [a-zA-Z_] [a-zA-Z0-9_.-]*
|
|
||||||
auto ident = p.rule("ident", p.chars("[a-zA-Z_]", 1, 1) + p.chars("[a-zA-Z0-9_.-]", 0, -1));
|
|
||||||
|
|
||||||
// value ::= (!eol-start .)*
|
|
||||||
auto eol_start = p.rule("eol-start", ws + (p.chars("[;#]", 1, 1) | newline | p.end()));
|
|
||||||
auto value = p.rule("value", p.zero_or_more(p.negate(eol_start) + p.any()));
|
|
||||||
|
|
||||||
// header-line ::= "[" ws ident ws "]" eol
|
|
||||||
auto header_line = p.rule("header-line", "[" + ws + p.tag("section-name", p.chars("[^]]")) + ws + "]" + eol);
|
|
||||||
|
|
||||||
// kv-line ::= ident ws "=" ws value eol
|
|
||||||
auto kv_line = p.rule("kv-line", p.tag("key", ident) + ws + "=" + ws + p.tag("value", value) + eol);
|
|
||||||
|
|
||||||
// comment-line ::= ws comment (newline / EOF)
|
|
||||||
auto comment_line = p.rule("comment-line", ws + comment + (newline | p.end()));
|
|
||||||
|
|
||||||
// blank-line ::= ws (newline / EOF)
|
|
||||||
auto blank_line = p.rule("blank-line", ws + (newline | p.end()));
|
|
||||||
|
|
||||||
// line ::= header-line / kv-line / comment-line / blank-line
|
|
||||||
auto line = p.rule("line", header_line | kv_line | comment_line | blank_line);
|
|
||||||
|
|
||||||
// ini ::= line* EOF
|
|
||||||
auto ini = p.rule("ini", p.zero_or_more(line) + p.end());
|
|
||||||
|
|
||||||
return ini;
|
|
||||||
});
|
|
||||||
|
|
||||||
common_peg_parse_context ctx(contents);
|
|
||||||
const auto result = parser.parse(ctx);
|
|
||||||
if (!result.success()) {
|
|
||||||
throw std::runtime_error("failed to parse server config file: " + path);
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string current_section = COMMON_PRESET_DEFAULT_NAME;
|
|
||||||
std::string current_key;
|
|
||||||
|
|
||||||
ctx.ast.visit(result, [&](const auto & node) {
|
|
||||||
if (node.tag == "section-name") {
|
|
||||||
const std::string section = std::string(node.text);
|
|
||||||
current_section = section;
|
|
||||||
parsed[current_section] = {};
|
|
||||||
} else if (node.tag == "key") {
|
|
||||||
const std::string key = std::string(node.text);
|
|
||||||
current_key = key;
|
|
||||||
} else if (node.tag == "value" && !current_key.empty() && !current_section.empty()) {
|
|
||||||
parsed[current_section][current_key] = std::string(node.text);
|
|
||||||
current_key.clear();
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
return parsed;
|
|
||||||
}
|
|
||||||
|
|
||||||
static std::map<std::string, common_arg> get_map_key_opt(common_params_context & ctx_params) {
|
|
||||||
std::map<std::string, common_arg> mapping;
|
|
||||||
for (const auto & opt : ctx_params.options) {
|
|
||||||
for (const auto & env : opt.get_env()) {
|
|
||||||
mapping[env] = opt;
|
|
||||||
}
|
|
||||||
for (const auto & arg : opt.get_args()) {
|
|
||||||
mapping[rm_leading_dashes(arg)] = opt;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return mapping;
|
|
||||||
}
|
|
||||||
|
|
||||||
static bool is_bool_arg(const common_arg & arg) {
|
|
||||||
return !arg.args_neg.empty();
|
|
||||||
}
|
|
||||||
|
|
||||||
static std::string parse_bool_arg(const common_arg & arg, const std::string & key, const std::string & value) {
|
|
||||||
// if this is a negated arg, we need to reverse the value
|
|
||||||
for (const auto & neg_arg : arg.args_neg) {
|
|
||||||
if (rm_leading_dashes(neg_arg) == key) {
|
|
||||||
return common_arg_utils::is_truthy(value) ? "false" : "true";
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// otherwise, not negated
|
|
||||||
return value;
|
|
||||||
}
|
|
||||||
|
|
||||||
common_preset_context::common_preset_context(llama_example ex, bool only_remote_allowed)
|
|
||||||
: ctx_params(common_params_parser_init(default_params, ex)) {
|
|
||||||
common_params_add_preset_options(ctx_params.options);
|
|
||||||
key_to_opt = get_map_key_opt(ctx_params);
|
|
||||||
|
|
||||||
// setup allowed keys if only_remote_allowed is true
|
|
||||||
if (only_remote_allowed) {
|
|
||||||
filter_allowed_keys = true;
|
|
||||||
allowed_keys = get_remote_preset_whitelist(key_to_opt);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
common_presets common_preset_context::load_from_ini(const std::string & path, common_preset & global) const {
|
|
||||||
common_presets out;
|
|
||||||
auto ini_data = parse_ini_from_file(path);
|
|
||||||
|
|
||||||
for (auto section : ini_data) {
|
|
||||||
common_preset preset;
|
|
||||||
if (section.first.empty()) {
|
|
||||||
preset.name = COMMON_PRESET_DEFAULT_NAME;
|
|
||||||
} else {
|
|
||||||
preset.name = section.first;
|
|
||||||
}
|
|
||||||
LOG_DBG("loading preset: %s\n", preset.name.c_str());
|
|
||||||
for (const auto & [key, value] : section.second) {
|
|
||||||
if (key == "version") {
|
|
||||||
// skip version key (reserved for future use)
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
LOG_DBG("option: %s = %s\n", key.c_str(), value.c_str());
|
|
||||||
if (filter_allowed_keys && allowed_keys.find(key) == allowed_keys.end()) {
|
|
||||||
throw std::runtime_error(string_format(
|
|
||||||
"option '%s' is not allowed in remote presets",
|
|
||||||
key.c_str()
|
|
||||||
));
|
|
||||||
}
|
|
||||||
if (key_to_opt.find(key) != key_to_opt.end()) {
|
|
||||||
const auto & opt = key_to_opt.at(key);
|
|
||||||
if (is_bool_arg(opt)) {
|
|
||||||
preset.options[opt] = parse_bool_arg(opt, key, value);
|
|
||||||
} else {
|
|
||||||
preset.options[opt] = value;
|
|
||||||
}
|
|
||||||
LOG_DBG("accepted option: %s = %s\n", key.c_str(), preset.options[opt].c_str());
|
|
||||||
} else {
|
|
||||||
throw std::runtime_error(string_format(
|
|
||||||
"option '%s' not recognized in preset '%s'",
|
|
||||||
key.c_str(), preset.name.c_str()
|
|
||||||
));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (preset.name == "*") {
|
|
||||||
// handle global preset
|
|
||||||
global = preset;
|
|
||||||
} else {
|
|
||||||
out[preset.name] = preset;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return out;
|
|
||||||
}
|
|
||||||
|
|
||||||
common_presets common_preset_context::load_from_cache() const {
|
|
||||||
common_presets out;
|
|
||||||
|
|
||||||
auto cached_models = common_list_cached_models();
|
|
||||||
for (const auto & model : cached_models) {
|
|
||||||
common_preset preset;
|
|
||||||
preset.name = model.to_string();
|
|
||||||
preset.set_option(*this, "LLAMA_ARG_HF_REPO", model.to_string());
|
|
||||||
out[preset.name] = preset;
|
|
||||||
}
|
|
||||||
|
|
||||||
return out;
|
|
||||||
}
|
|
||||||
|
|
||||||
struct local_model {
|
|
||||||
std::string name;
|
|
||||||
std::string path;
|
|
||||||
std::string path_mmproj;
|
|
||||||
};
|
|
||||||
|
|
||||||
common_presets common_preset_context::load_from_models_dir(const std::string & models_dir) const {
|
|
||||||
if (!std::filesystem::exists(models_dir) || !std::filesystem::is_directory(models_dir)) {
|
|
||||||
throw std::runtime_error(string_format("error: '%s' does not exist or is not a directory\n", models_dir.c_str()));
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<local_model> models;
|
|
||||||
auto scan_subdir = [&models](const std::string & subdir_path, const std::string & name) {
|
|
||||||
auto files = fs_list(subdir_path, false);
|
|
||||||
common_file_info model_file;
|
|
||||||
common_file_info first_shard_file;
|
|
||||||
common_file_info mmproj_file;
|
|
||||||
for (const auto & file : files) {
|
|
||||||
if (string_ends_with(file.name, ".gguf")) {
|
|
||||||
if (file.name.find("mmproj") != std::string::npos) {
|
|
||||||
mmproj_file = file;
|
|
||||||
} else if (file.name.find("-00001-of-") != std::string::npos) {
|
|
||||||
first_shard_file = file;
|
|
||||||
} else {
|
|
||||||
model_file = file;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// single file model
|
|
||||||
local_model model{
|
|
||||||
/* name */ name,
|
|
||||||
/* path */ first_shard_file.path.empty() ? model_file.path : first_shard_file.path,
|
|
||||||
/* path_mmproj */ mmproj_file.path // can be empty
|
|
||||||
};
|
|
||||||
if (!model.path.empty()) {
|
|
||||||
models.push_back(model);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
auto files = fs_list(models_dir, true);
|
|
||||||
for (const auto & file : files) {
|
|
||||||
if (file.is_dir) {
|
|
||||||
scan_subdir(file.path, file.name);
|
|
||||||
} else if (string_ends_with(file.name, ".gguf")) {
|
|
||||||
// single file model
|
|
||||||
std::string name = file.name;
|
|
||||||
string_replace_all(name, ".gguf", "");
|
|
||||||
local_model model{
|
|
||||||
/* name */ name,
|
|
||||||
/* path */ file.path,
|
|
||||||
/* path_mmproj */ ""
|
|
||||||
};
|
|
||||||
models.push_back(model);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// convert local models to presets
|
|
||||||
common_presets out;
|
|
||||||
for (const auto & model : models) {
|
|
||||||
common_preset preset;
|
|
||||||
preset.name = model.name;
|
|
||||||
preset.set_option(*this, "LLAMA_ARG_MODEL", model.path);
|
|
||||||
if (!model.path_mmproj.empty()) {
|
|
||||||
preset.set_option(*this, "LLAMA_ARG_MMPROJ", model.path_mmproj);
|
|
||||||
}
|
|
||||||
out[preset.name] = preset;
|
|
||||||
}
|
|
||||||
|
|
||||||
return out;
|
|
||||||
}
|
|
||||||
|
|
||||||
common_preset common_preset_context::load_from_args(int argc, char ** argv) const {
|
|
||||||
common_preset preset;
|
|
||||||
preset.name = COMMON_PRESET_DEFAULT_NAME;
|
|
||||||
|
|
||||||
bool ok = common_params_to_map(argc, argv, ctx_params.ex, preset.options);
|
|
||||||
if (!ok) {
|
|
||||||
throw std::runtime_error("failed to parse CLI arguments into preset");
|
|
||||||
}
|
|
||||||
|
|
||||||
return preset;
|
|
||||||
}
|
|
||||||
|
|
||||||
common_presets common_preset_context::cascade(const common_presets & base, const common_presets & added) const {
|
|
||||||
common_presets out = base; // copy
|
|
||||||
for (const auto & [name, preset_added] : added) {
|
|
||||||
if (out.find(name) != out.end()) {
|
|
||||||
// if exists, merge
|
|
||||||
common_preset & target = out[name];
|
|
||||||
target.merge(preset_added);
|
|
||||||
} else {
|
|
||||||
// otherwise, add directly
|
|
||||||
out[name] = preset_added;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return out;
|
|
||||||
}
|
|
||||||
|
|
||||||
common_presets common_preset_context::cascade(const common_preset & base, const common_presets & presets) const {
|
|
||||||
common_presets out;
|
|
||||||
for (const auto & [name, preset] : presets) {
|
|
||||||
common_preset tmp = base; // copy
|
|
||||||
tmp.name = name;
|
|
||||||
tmp.merge(preset);
|
|
||||||
out[name] = std::move(tmp);
|
|
||||||
}
|
|
||||||
return out;
|
|
||||||
}
|
|
||||||
@@ -1,83 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
#include "common.h"
|
|
||||||
#include "arg.h"
|
|
||||||
|
|
||||||
#include <string>
|
|
||||||
#include <vector>
|
|
||||||
#include <map>
|
|
||||||
#include <set>
|
|
||||||
|
|
||||||
//
|
|
||||||
// INI preset parser and writer
|
|
||||||
//
|
|
||||||
|
|
||||||
constexpr const char * COMMON_PRESET_DEFAULT_NAME = "default";
|
|
||||||
|
|
||||||
struct common_preset_context;
|
|
||||||
|
|
||||||
struct common_preset {
|
|
||||||
std::string name;
|
|
||||||
|
|
||||||
// options are stored as common_arg to string mapping, representing CLI arg and its value
|
|
||||||
std::map<common_arg, std::string> options;
|
|
||||||
|
|
||||||
// convert preset to CLI argument list
|
|
||||||
std::vector<std::string> to_args(const std::string & bin_path = "") const;
|
|
||||||
|
|
||||||
// convert preset to INI format string
|
|
||||||
std::string to_ini() const;
|
|
||||||
|
|
||||||
// TODO: maybe implement to_env() if needed
|
|
||||||
|
|
||||||
// modify preset options where argument is identified by its env variable
|
|
||||||
void set_option(const common_preset_context & ctx, const std::string & env, const std::string & value);
|
|
||||||
|
|
||||||
// unset option by its env variable
|
|
||||||
void unset_option(const std::string & env);
|
|
||||||
|
|
||||||
// get option value by its env variable, return false if not found
|
|
||||||
bool get_option(const std::string & env, std::string & value) const;
|
|
||||||
|
|
||||||
// merge another preset into this one, overwriting existing options
|
|
||||||
void merge(const common_preset & other);
|
|
||||||
|
|
||||||
// apply preset options to common_params
|
|
||||||
void apply_to_params(common_params & params) const;
|
|
||||||
};
|
|
||||||
|
|
||||||
// interface for multiple presets in one file
|
|
||||||
using common_presets = std::map<std::string, common_preset>;
|
|
||||||
|
|
||||||
// context for loading and editing presets
|
|
||||||
struct common_preset_context {
|
|
||||||
common_params default_params; // unused for now
|
|
||||||
common_params_context ctx_params;
|
|
||||||
std::map<std::string, common_arg> key_to_opt;
|
|
||||||
|
|
||||||
bool filter_allowed_keys = false;
|
|
||||||
std::set<std::string> allowed_keys;
|
|
||||||
|
|
||||||
// if only_remote_allowed is true, only accept whitelisted keys
|
|
||||||
common_preset_context(llama_example ex, bool only_remote_allowed = false);
|
|
||||||
|
|
||||||
// load presets from INI file
|
|
||||||
common_presets load_from_ini(const std::string & path, common_preset & global) const;
|
|
||||||
|
|
||||||
// generate presets from cached models
|
|
||||||
common_presets load_from_cache() const;
|
|
||||||
|
|
||||||
// generate presets from local models directory
|
|
||||||
// for the directory structure, see "Using multiple models" in server/README.md
|
|
||||||
common_presets load_from_models_dir(const std::string & models_dir) const;
|
|
||||||
|
|
||||||
// generate one preset from CLI arguments
|
|
||||||
common_preset load_from_args(int argc, char ** argv) const;
|
|
||||||
|
|
||||||
// cascade multiple presets if exist on both: base < added
|
|
||||||
// if preset does not exist in base, it will be added without modification
|
|
||||||
common_presets cascade(const common_presets & base, const common_presets & added) const;
|
|
||||||
|
|
||||||
// apply presets over a base preset (same idea as CSS cascading)
|
|
||||||
common_presets cascade(const common_preset & base, const common_presets & presets) const;
|
|
||||||
};
|
|
||||||
@@ -1,204 +0,0 @@
|
|||||||
#include "regex-partial.h"
|
|
||||||
#include "common.h"
|
|
||||||
#include <functional>
|
|
||||||
#include <optional>
|
|
||||||
|
|
||||||
common_regex::common_regex(const std::string & pattern) :
|
|
||||||
pattern(pattern),
|
|
||||||
rx(pattern),
|
|
||||||
rx_reversed_partial(regex_to_reversed_partial_regex(pattern)) {}
|
|
||||||
|
|
||||||
common_regex_match common_regex::search(const std::string & input, size_t pos, bool as_match) const {
|
|
||||||
std::smatch match;
|
|
||||||
if (pos > input.size()) {
|
|
||||||
throw std::runtime_error("Position out of bounds");
|
|
||||||
}
|
|
||||||
auto start = input.begin() + pos;
|
|
||||||
auto found = as_match
|
|
||||||
? std::regex_match(start, input.end(), match, rx)
|
|
||||||
: std::regex_search(start, input.end(), match, rx);
|
|
||||||
if (found) {
|
|
||||||
common_regex_match res;
|
|
||||||
res.type = COMMON_REGEX_MATCH_TYPE_FULL;
|
|
||||||
for (size_t i = 0; i < match.size(); ++i) {
|
|
||||||
auto begin = pos + match.position(i);
|
|
||||||
res.groups.emplace_back(begin, begin + match.length(i));
|
|
||||||
}
|
|
||||||
return res;
|
|
||||||
}
|
|
||||||
std::match_results<std::string::const_reverse_iterator> srmatch;
|
|
||||||
if (std::regex_search(input.rbegin(), input.rend() - pos, srmatch, rx_reversed_partial, std::regex_constants::match_continuous)) {
|
|
||||||
auto group = srmatch[1].str();
|
|
||||||
if (group.length() != 0) {
|
|
||||||
auto it = srmatch[1].second.base();
|
|
||||||
// auto position = static_cast<size_t>(std::distance(input.begin(), it));
|
|
||||||
if ((!as_match) || it == input.begin()) {
|
|
||||||
common_regex_match res;
|
|
||||||
res.type = COMMON_REGEX_MATCH_TYPE_PARTIAL;
|
|
||||||
const size_t begin = std::distance(input.begin(), it);
|
|
||||||
const size_t end = input.size();
|
|
||||||
if (begin == std::string::npos || end == std::string::npos || begin > end) {
|
|
||||||
throw std::runtime_error("Invalid range");
|
|
||||||
}
|
|
||||||
res.groups.push_back({begin, end});
|
|
||||||
return res;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return {};
|
|
||||||
}
|
|
||||||
|
|
||||||
/*
|
|
||||||
Transforms a regex pattern to a partial match pattern that operates on a reversed input string to find partial final matches of the original pattern.
|
|
||||||
|
|
||||||
Ideally we'd like to use boost::match_partial (https://beta.boost.org/doc/libs/1_59_0/libs/regex/doc/html/boost_regex/partial_matches.html)
|
|
||||||
to see if a string ends with a partial regex match, but but it's not in std::regex yet.
|
|
||||||
Instead, we'll the regex into a partial match regex operating as a full match on the reverse iterators of the input.
|
|
||||||
|
|
||||||
- /abcd/ -> ^(dcba|cba|ba|a) -> ^((?:(?:(?:(?:d)?c)?b)?a)
|
|
||||||
- /a|b/ -> ^(a|b)
|
|
||||||
- /a*?/ -> error, could match ""
|
|
||||||
- /a*b/ -> ^((?:b)?a*+) (final repetitions become eager)
|
|
||||||
- /.*?ab/ -> ^((?:b)?a) (omit .*)
|
|
||||||
- /a.*?b/ -> ^((?:b)?.*?a) (keep reluctant matches)
|
|
||||||
- /a(bc)d/ -> ^((?:(?:d)?(?:(?:c)?b))?a)
|
|
||||||
- /a(bc|de)/ -> ^((?:(?:(?:e)?d)?|(?:(?:c)?b)?)?a)
|
|
||||||
- /ab{2,4}c/ -> ^cbbb?b?a -> ^((?:(?:(?:(?:(?:c)?b)?b)?b?)?b?)?a)
|
|
||||||
|
|
||||||
The regex will match a reversed string fully, and the end of the first (And only) capturing group will indicate the reversed start of the original partial pattern.
|
|
||||||
All other groups are turned into non-capturing groups, and reluctant quantifiers are ignored.
|
|
||||||
*/
|
|
||||||
std::string regex_to_reversed_partial_regex(const std::string & pattern) {
|
|
||||||
auto it = pattern.begin();
|
|
||||||
const auto end = pattern.end();
|
|
||||||
|
|
||||||
std::function<std::string()> process = [&]() {
|
|
||||||
std::vector<std::vector<std::string>> alternatives(1);
|
|
||||||
std::vector<std::string> * sequence = &alternatives.back();
|
|
||||||
|
|
||||||
while (it != end) {
|
|
||||||
if (*it == '[') {
|
|
||||||
auto start = it;
|
|
||||||
++it;
|
|
||||||
while (it != end) {
|
|
||||||
if ((*it == '\\') && (++it != end)) {
|
|
||||||
++it;
|
|
||||||
} else if ((it != end) && (*it == ']')) {
|
|
||||||
break;
|
|
||||||
} else {
|
|
||||||
++it;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (it == end) {
|
|
||||||
throw std::runtime_error("Unmatched '[' in pattern");
|
|
||||||
}
|
|
||||||
++it;
|
|
||||||
sequence->push_back(std::string(start, it));
|
|
||||||
} else if (*it == '*' || *it == '?' || *it == '+') {
|
|
||||||
if (sequence->empty()) {
|
|
||||||
throw std::runtime_error("Quantifier without preceding element");
|
|
||||||
}
|
|
||||||
sequence->back() += *it;
|
|
||||||
auto is_star = *it == '*';
|
|
||||||
++it;
|
|
||||||
if (is_star) {
|
|
||||||
if (*it == '?') {
|
|
||||||
++it;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else if (*it == '{') {
|
|
||||||
if (sequence->empty()) {
|
|
||||||
throw std::runtime_error("Repetition without preceding element");
|
|
||||||
}
|
|
||||||
++it;
|
|
||||||
auto start = it;
|
|
||||||
while (it != end && *it != '}') {
|
|
||||||
++it;
|
|
||||||
}
|
|
||||||
if (it == end) {
|
|
||||||
throw std::runtime_error("Unmatched '{' in pattern");
|
|
||||||
}
|
|
||||||
auto parts = string_split(std::string(start, it), ",");
|
|
||||||
++it;
|
|
||||||
if (parts.size() > 2) {
|
|
||||||
throw std::runtime_error("Invalid repetition range in pattern");
|
|
||||||
}
|
|
||||||
|
|
||||||
auto parseOptInt = [&](const std::string & s, const std::optional<int> & def = std::nullopt) -> std::optional<int> {
|
|
||||||
if (s.empty()) {
|
|
||||||
return def;
|
|
||||||
}
|
|
||||||
return std::stoi(s);
|
|
||||||
};
|
|
||||||
auto min = parseOptInt(parts[0], 0);
|
|
||||||
auto max = parts.size() == 1 ? min : parseOptInt(parts[1]);
|
|
||||||
if (min && max && *max < *min) {
|
|
||||||
throw std::runtime_error("Invalid repetition range in pattern");
|
|
||||||
}
|
|
||||||
// Brutal but... let's repeat at least min times, then ? for the delta between min & max (or * for unbounded)
|
|
||||||
auto part = sequence->back();
|
|
||||||
sequence->pop_back();
|
|
||||||
for (int i = 0; i < *min; i++) {
|
|
||||||
sequence->push_back(part);
|
|
||||||
}
|
|
||||||
if (max) {
|
|
||||||
for (int i = *min; i < *max; i++) {
|
|
||||||
sequence->push_back(part + "?");
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
sequence->push_back(part + "*");
|
|
||||||
}
|
|
||||||
} else if (*it == '(') {
|
|
||||||
++it;
|
|
||||||
if (it != end && *it == '?' && (it + 1 != end) && *(it + 1) == ':') {
|
|
||||||
it += 2;
|
|
||||||
}
|
|
||||||
auto sub = process();
|
|
||||||
if (*it != ')') {
|
|
||||||
throw std::runtime_error("Unmatched '(' in pattern");
|
|
||||||
}
|
|
||||||
++it;
|
|
||||||
auto & part = sequence->emplace_back("(?:");
|
|
||||||
part += sub;
|
|
||||||
part += ")";
|
|
||||||
} else if (*it == ')') {
|
|
||||||
break;
|
|
||||||
} else if (*it == '|') {
|
|
||||||
++it;
|
|
||||||
alternatives.emplace_back();
|
|
||||||
sequence = &alternatives.back();
|
|
||||||
} else if (*it == '\\' && (++it != end)) {
|
|
||||||
auto str = std::string("\\") + *it;
|
|
||||||
sequence->push_back(str);
|
|
||||||
++it;
|
|
||||||
} else if (it != end) {
|
|
||||||
sequence->push_back(std::string(1, *it));
|
|
||||||
++it;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// /abcd/ -> ^(dcba|cba|ba|a) -> ^((?:(?:(?:d)?c)?b)?a)
|
|
||||||
// if n(=4) parts, opening n-1(=3) non-capturing groups after the 1 capturing group
|
|
||||||
// We'll do the outermost capturing group and final .* in the enclosing function.
|
|
||||||
std::vector<std::string> res_alts;
|
|
||||||
for (const auto & parts : alternatives) {
|
|
||||||
auto & res = res_alts.emplace_back();
|
|
||||||
for (size_t i = 0; i < parts.size() - 1; i++) {
|
|
||||||
res += "(?:";
|
|
||||||
}
|
|
||||||
for (auto it = parts.rbegin(); it != parts.rend(); ++it) {
|
|
||||||
res += *it;
|
|
||||||
if (it != parts.rend() - 1) {
|
|
||||||
res += ")?";
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return string_join(res_alts, "|");
|
|
||||||
};
|
|
||||||
auto res = process();
|
|
||||||
if (it != end) {
|
|
||||||
throw std::runtime_error("Unmatched '(' in pattern");
|
|
||||||
}
|
|
||||||
|
|
||||||
return "^(" + res + ")";
|
|
||||||
}
|
|
||||||
@@ -1,56 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
#include <regex>
|
|
||||||
#include <string>
|
|
||||||
|
|
||||||
enum common_regex_match_type {
|
|
||||||
COMMON_REGEX_MATCH_TYPE_NONE,
|
|
||||||
COMMON_REGEX_MATCH_TYPE_PARTIAL,
|
|
||||||
COMMON_REGEX_MATCH_TYPE_FULL,
|
|
||||||
};
|
|
||||||
|
|
||||||
struct common_string_range {
|
|
||||||
size_t begin;
|
|
||||||
size_t end;
|
|
||||||
common_string_range(size_t begin, size_t end) : begin(begin), end(end) {
|
|
||||||
if (begin > end) {
|
|
||||||
throw std::runtime_error("Invalid range");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// prevent default ctor
|
|
||||||
common_string_range() = delete;
|
|
||||||
bool empty() const {
|
|
||||||
return begin == end;
|
|
||||||
}
|
|
||||||
bool operator==(const common_string_range & other) const {
|
|
||||||
return begin == other.begin && end == other.end;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
struct common_regex_match {
|
|
||||||
common_regex_match_type type = COMMON_REGEX_MATCH_TYPE_NONE;
|
|
||||||
std::vector<common_string_range> groups;
|
|
||||||
|
|
||||||
bool operator==(const common_regex_match & other) const {
|
|
||||||
return type == other.type && groups == other.groups;
|
|
||||||
}
|
|
||||||
bool operator!=(const common_regex_match & other) const {
|
|
||||||
return !(*this == other);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
class common_regex {
|
|
||||||
std::string pattern;
|
|
||||||
std::regex rx;
|
|
||||||
std::regex rx_reversed_partial;
|
|
||||||
|
|
||||||
public:
|
|
||||||
explicit common_regex(const std::string & pattern);
|
|
||||||
|
|
||||||
common_regex_match search(const std::string & input, size_t pos, bool as_match = false) const;
|
|
||||||
|
|
||||||
const std::string & str() const { return pattern; }
|
|
||||||
};
|
|
||||||
|
|
||||||
// For testing only (pretty print of failures).
|
|
||||||
std::string regex_to_reversed_partial_regex(const std::string & pattern);
|
|
||||||
@@ -1,712 +0,0 @@
|
|||||||
#include "sampling.h"
|
|
||||||
|
|
||||||
#include "common.h"
|
|
||||||
#include "log.h"
|
|
||||||
|
|
||||||
#include <algorithm>
|
|
||||||
#include <cmath>
|
|
||||||
#include <cstring>
|
|
||||||
#include <unordered_map>
|
|
||||||
|
|
||||||
// the ring buffer works similarly to std::deque, but with a fixed capacity
|
|
||||||
// TODO: deduplicate with llama-impl.h
|
|
||||||
template<typename T>
|
|
||||||
struct ring_buffer {
|
|
||||||
ring_buffer(size_t cap) : capacity(cap), data(cap) {}
|
|
||||||
|
|
||||||
T & front() {
|
|
||||||
if (sz == 0) {
|
|
||||||
throw std::runtime_error("ring buffer is empty");
|
|
||||||
}
|
|
||||||
return data[first];
|
|
||||||
}
|
|
||||||
|
|
||||||
const T & front() const {
|
|
||||||
if (sz == 0) {
|
|
||||||
throw std::runtime_error("ring buffer is empty");
|
|
||||||
}
|
|
||||||
return data[first];
|
|
||||||
}
|
|
||||||
|
|
||||||
T & back() {
|
|
||||||
if (sz == 0) {
|
|
||||||
throw std::runtime_error("ring buffer is empty");
|
|
||||||
}
|
|
||||||
return data[pos];
|
|
||||||
}
|
|
||||||
|
|
||||||
const T & back() const {
|
|
||||||
if (sz == 0) {
|
|
||||||
throw std::runtime_error("ring buffer is empty");
|
|
||||||
}
|
|
||||||
return data[pos];
|
|
||||||
}
|
|
||||||
|
|
||||||
void push_back(const T & value) {
|
|
||||||
if (sz == capacity) {
|
|
||||||
// advance the start when buffer is full
|
|
||||||
first = (first + 1) % capacity;
|
|
||||||
} else {
|
|
||||||
sz++;
|
|
||||||
}
|
|
||||||
data[pos] = value;
|
|
||||||
pos = (pos + 1) % capacity;
|
|
||||||
}
|
|
||||||
|
|
||||||
T pop_front() {
|
|
||||||
if (sz == 0) {
|
|
||||||
throw std::runtime_error("ring buffer is empty");
|
|
||||||
}
|
|
||||||
T value = data[first];
|
|
||||||
first = (first + 1) % capacity;
|
|
||||||
sz--;
|
|
||||||
return value;
|
|
||||||
}
|
|
||||||
|
|
||||||
const T & rat(size_t i) const {
|
|
||||||
if (i >= sz) {
|
|
||||||
throw std::runtime_error("ring buffer: index out of bounds");
|
|
||||||
}
|
|
||||||
return data[(first + sz - i - 1) % capacity];
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<T> to_vector() const {
|
|
||||||
std::vector<T> result;
|
|
||||||
result.reserve(sz);
|
|
||||||
for (size_t i = 0; i < sz; i++) {
|
|
||||||
result.push_back(data[(first + i) % capacity]);
|
|
||||||
}
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
void clear() {
|
|
||||||
// here only reset the status of the buffer
|
|
||||||
sz = 0;
|
|
||||||
first = 0;
|
|
||||||
pos = 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool empty() const {
|
|
||||||
return sz == 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
size_t size() const {
|
|
||||||
return sz;
|
|
||||||
}
|
|
||||||
|
|
||||||
size_t capacity = 0;
|
|
||||||
size_t sz = 0;
|
|
||||||
size_t first = 0;
|
|
||||||
size_t pos = 0;
|
|
||||||
std::vector<T> data;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct common_sampler {
|
|
||||||
common_params_sampling params;
|
|
||||||
|
|
||||||
struct llama_sampler * grmr;
|
|
||||||
struct llama_sampler * chain;
|
|
||||||
|
|
||||||
ring_buffer<llama_token> prev;
|
|
||||||
|
|
||||||
std::vector<llama_token_data> cur;
|
|
||||||
|
|
||||||
llama_token_data_array cur_p;
|
|
||||||
|
|
||||||
void reset() {
|
|
||||||
prev.clear();
|
|
||||||
|
|
||||||
llama_sampler_reset(chain);
|
|
||||||
}
|
|
||||||
|
|
||||||
void set_logits(struct llama_context * ctx, int idx) {
|
|
||||||
const float * sampled_probs = llama_get_sampled_probs_ith (ctx, idx);
|
|
||||||
const float * sampled_logits = llama_get_sampled_logits_ith (ctx, idx);
|
|
||||||
const llama_token * sampled_ids = llama_get_sampled_candidates_ith(ctx, idx);
|
|
||||||
|
|
||||||
const llama_model * model = llama_get_model(ctx);
|
|
||||||
const llama_vocab * vocab = llama_model_get_vocab(model);
|
|
||||||
|
|
||||||
const int n_vocab = llama_vocab_n_tokens(vocab);
|
|
||||||
|
|
||||||
if (sampled_probs) {
|
|
||||||
const uint32_t sampled_probs_count = llama_get_sampled_probs_count_ith(ctx, idx);
|
|
||||||
cur.resize(sampled_probs_count);
|
|
||||||
for (uint32_t i = 0; i < sampled_probs_count; ++i) {
|
|
||||||
cur[i] = llama_token_data{sampled_ids[i], sampled_logits[i], sampled_probs[i]};
|
|
||||||
}
|
|
||||||
} else if (sampled_logits) {
|
|
||||||
const uint32_t sampled_logits_count = llama_get_sampled_logits_count_ith(ctx, idx);
|
|
||||||
cur.resize(sampled_logits_count);
|
|
||||||
for (uint32_t i = 0; i < sampled_logits_count; i++) {
|
|
||||||
cur[i] = llama_token_data{sampled_ids[i], sampled_logits[i], 0.0f};
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
const auto * logits = llama_get_logits_ith(ctx, idx);
|
|
||||||
GGML_ASSERT(logits != nullptr);
|
|
||||||
cur.resize(n_vocab);
|
|
||||||
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
|
|
||||||
cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
cur_p = { cur.data(), cur.size(), -1, false };
|
|
||||||
}
|
|
||||||
|
|
||||||
common_time_meas tm() {
|
|
||||||
return common_time_meas(t_total_us, params.no_perf);
|
|
||||||
}
|
|
||||||
|
|
||||||
mutable int64_t t_total_us = 0;
|
|
||||||
};
|
|
||||||
|
|
||||||
std::string common_params_sampling::print() const {
|
|
||||||
char result[1024];
|
|
||||||
|
|
||||||
snprintf(result, sizeof(result),
|
|
||||||
"\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n"
|
|
||||||
"\tdry_multiplier = %.3f, dry_base = %.3f, dry_allowed_length = %d, dry_penalty_last_n = %d\n"
|
|
||||||
"\ttop_k = %d, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, top_n_sigma = %.3f, temp = %.3f\n"
|
|
||||||
"\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f",
|
|
||||||
penalty_last_n, penalty_repeat, penalty_freq, penalty_present,
|
|
||||||
dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n,
|
|
||||||
top_k, top_p, min_p, xtc_probability, xtc_threshold, typ_p, top_n_sigma, temp,
|
|
||||||
mirostat, mirostat_eta, mirostat_tau);
|
|
||||||
|
|
||||||
return std::string(result);
|
|
||||||
}
|
|
||||||
|
|
||||||
struct common_sampler * common_sampler_init(const struct llama_model * model, struct common_params_sampling & params) {
|
|
||||||
const llama_vocab * vocab = llama_model_get_vocab(model);
|
|
||||||
|
|
||||||
llama_sampler_chain_params lparams = llama_sampler_chain_default_params();
|
|
||||||
|
|
||||||
lparams.no_perf = params.no_perf;
|
|
||||||
|
|
||||||
llama_sampler * grmr = nullptr;
|
|
||||||
llama_sampler * chain = llama_sampler_chain_init(lparams);
|
|
||||||
|
|
||||||
std::vector<llama_sampler *> samplers;
|
|
||||||
|
|
||||||
if (params.grammar.compare(0, 11, "%llguidance") == 0) {
|
|
||||||
#ifdef LLAMA_USE_LLGUIDANCE
|
|
||||||
grmr = llama_sampler_init_llg(vocab, "lark", params.grammar.c_str());
|
|
||||||
#else
|
|
||||||
GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled");
|
|
||||||
#endif // LLAMA_USE_LLGUIDANCE
|
|
||||||
} else {
|
|
||||||
std::vector<std::string> trigger_patterns;
|
|
||||||
std::vector<llama_token> trigger_tokens;
|
|
||||||
for (const auto & trigger : params.grammar_triggers) {
|
|
||||||
switch (trigger.type) {
|
|
||||||
case COMMON_GRAMMAR_TRIGGER_TYPE_WORD:
|
|
||||||
{
|
|
||||||
const auto & word = trigger.value;
|
|
||||||
trigger_patterns.push_back(regex_escape(word));
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN:
|
|
||||||
{
|
|
||||||
trigger_patterns.push_back(trigger.value);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL:
|
|
||||||
{
|
|
||||||
const auto & pattern = trigger.value;
|
|
||||||
std::string anchored = "^$";
|
|
||||||
if (!pattern.empty()) {
|
|
||||||
anchored = (pattern.front() != '^' ? "^" : "")
|
|
||||||
+ pattern
|
|
||||||
+ (pattern.back() != '$' ? "$" : "");
|
|
||||||
}
|
|
||||||
trigger_patterns.push_back(anchored);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
case COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN:
|
|
||||||
{
|
|
||||||
const auto token = trigger.token;
|
|
||||||
trigger_tokens.push_back(token);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
GGML_ASSERT(false && "unknown trigger type");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<const char *> trigger_patterns_c;
|
|
||||||
trigger_patterns_c.reserve(trigger_patterns.size());
|
|
||||||
for (const auto & regex : trigger_patterns) {
|
|
||||||
trigger_patterns_c.push_back(regex.c_str());
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!params.grammar.empty()) {
|
|
||||||
if (params.grammar_lazy) {
|
|
||||||
grmr = llama_sampler_init_grammar_lazy_patterns(vocab, params.grammar.c_str(), "root",
|
|
||||||
trigger_patterns_c.data(), trigger_patterns_c.size(),
|
|
||||||
trigger_tokens.data(), trigger_tokens.size());
|
|
||||||
} else {
|
|
||||||
grmr = llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (params.has_logit_bias()) {
|
|
||||||
samplers.push_back(llama_sampler_init_logit_bias(llama_vocab_n_tokens(vocab), params.logit_bias.size(), params.logit_bias.data()));
|
|
||||||
}
|
|
||||||
|
|
||||||
if (params.mirostat == 0) {
|
|
||||||
for (const auto & cnstr : params.samplers) {
|
|
||||||
switch (cnstr) {
|
|
||||||
case COMMON_SAMPLER_TYPE_DRY:
|
|
||||||
{
|
|
||||||
std::vector<const char *> c_breakers;
|
|
||||||
c_breakers.reserve(params.dry_sequence_breakers.size());
|
|
||||||
for (const auto & str : params.dry_sequence_breakers) {
|
|
||||||
c_breakers.push_back(str.c_str());
|
|
||||||
}
|
|
||||||
|
|
||||||
samplers.push_back(llama_sampler_init_dry (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
case COMMON_SAMPLER_TYPE_TOP_K:
|
|
||||||
samplers.push_back(llama_sampler_init_top_k (params.top_k));
|
|
||||||
break;
|
|
||||||
case COMMON_SAMPLER_TYPE_TOP_P:
|
|
||||||
samplers.push_back(llama_sampler_init_top_p (params.top_p, params.min_keep));
|
|
||||||
break;
|
|
||||||
case COMMON_SAMPLER_TYPE_TOP_N_SIGMA:
|
|
||||||
samplers.push_back(llama_sampler_init_top_n_sigma(params.top_n_sigma));
|
|
||||||
break;
|
|
||||||
case COMMON_SAMPLER_TYPE_MIN_P:
|
|
||||||
samplers.push_back(llama_sampler_init_min_p (params.min_p, params.min_keep));
|
|
||||||
break;
|
|
||||||
case COMMON_SAMPLER_TYPE_XTC:
|
|
||||||
samplers.push_back(llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
|
|
||||||
break;
|
|
||||||
case COMMON_SAMPLER_TYPE_TYPICAL_P:
|
|
||||||
samplers.push_back(llama_sampler_init_typical (params.typ_p, params.min_keep));
|
|
||||||
break;
|
|
||||||
case COMMON_SAMPLER_TYPE_TEMPERATURE:
|
|
||||||
samplers.push_back(llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
|
|
||||||
break;
|
|
||||||
case COMMON_SAMPLER_TYPE_INFILL:
|
|
||||||
samplers.push_back(llama_sampler_init_infill (vocab));
|
|
||||||
break;
|
|
||||||
case COMMON_SAMPLER_TYPE_PENALTIES:
|
|
||||||
samplers.push_back(llama_sampler_init_penalties (params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
GGML_ASSERT(false && "unknown sampler type");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
samplers.push_back(llama_sampler_init_dist(params.seed));
|
|
||||||
} else if (params.mirostat == 1) {
|
|
||||||
samplers.push_back(llama_sampler_init_temp(params.temp));
|
|
||||||
samplers.push_back(llama_sampler_init_mirostat(llama_vocab_n_tokens(vocab), params.seed, params.mirostat_tau, params.mirostat_eta, 100));
|
|
||||||
} else if (params.mirostat == 2) {
|
|
||||||
samplers.push_back(llama_sampler_init_temp(params.temp));
|
|
||||||
samplers.push_back(llama_sampler_init_mirostat_v2(params.seed, params.mirostat_tau, params.mirostat_eta));
|
|
||||||
} else {
|
|
||||||
GGML_ASSERT(false && "unknown mirostat version");
|
|
||||||
}
|
|
||||||
|
|
||||||
for (auto * smpl : samplers) {
|
|
||||||
llama_sampler_chain_add(chain, smpl);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (grmr && params.backend_sampling) {
|
|
||||||
LOG_WRN("%s: backend sampling is not compatible with grammar, disabling\n", __func__);
|
|
||||||
|
|
||||||
params.backend_sampling = false;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto * result = new common_sampler {
|
|
||||||
/* .params = */ params,
|
|
||||||
/* .grmr = */ grmr,
|
|
||||||
/* .chain = */ chain,
|
|
||||||
/* .prev = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
|
|
||||||
/* .cur = */ {},
|
|
||||||
/* .cur_p = */ {},
|
|
||||||
};
|
|
||||||
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
void common_sampler_free(struct common_sampler * gsmpl) {
|
|
||||||
if (gsmpl) {
|
|
||||||
llama_sampler_free(gsmpl->grmr);
|
|
||||||
llama_sampler_free(gsmpl->chain);
|
|
||||||
|
|
||||||
delete gsmpl;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar) {
|
|
||||||
const auto tm = gsmpl->tm();
|
|
||||||
|
|
||||||
if (gsmpl->grmr && accept_grammar) {
|
|
||||||
llama_sampler_accept(gsmpl->grmr, token);
|
|
||||||
}
|
|
||||||
|
|
||||||
llama_sampler_accept(gsmpl->chain, token);
|
|
||||||
|
|
||||||
gsmpl->prev.push_back(token);
|
|
||||||
}
|
|
||||||
|
|
||||||
void common_sampler_reset(struct common_sampler * gsmpl) {
|
|
||||||
gsmpl->reset();
|
|
||||||
}
|
|
||||||
|
|
||||||
struct common_sampler * common_sampler_clone(common_sampler * gsmpl) {
|
|
||||||
return new common_sampler {
|
|
||||||
/* .params = */ gsmpl->params,
|
|
||||||
/* .grmr = */ llama_sampler_clone(gsmpl->grmr),
|
|
||||||
/* .chain = */ llama_sampler_clone(gsmpl->chain),
|
|
||||||
/* .prev = */ gsmpl->prev,
|
|
||||||
/* .cur = */ gsmpl->cur,
|
|
||||||
/* .cur_p = */ gsmpl->cur_p,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
// Print sampling and context performance statistics.
// Either argument may be nullptr to skip the corresponding section.
void common_perf_print(const struct llama_context * ctx, const struct common_sampler * gsmpl) {
    // TODO: measure grammar performance

    // total time spent inside common/sampling helpers (us -> ms)
    const double t_sampling_ms = gsmpl ? 1e-3*gsmpl->t_total_us : 0;

    llama_perf_sampler_data data_smpl;
    llama_perf_context_data data_ctx;

    memset(&data_smpl, 0, sizeof(data_smpl));
    memset(&data_ctx,  0, sizeof(data_ctx));

    if (gsmpl) {
        auto & data = data_smpl;

        data = llama_perf_sampler(gsmpl->chain);

        // note: the sampling time includes the samplers time + extra time spent in common/sampling
        LOG_INF("%s:    sampling time = %10.2f ms\n", __func__, t_sampling_ms);
        LOG_INF("%s:    samplers time = %10.2f ms / %5d tokens\n", __func__, data.t_sample_ms, data.n_sample);
    }

    if (ctx) {
        auto & data = data_ctx;

        data = llama_perf_context(ctx);

        const double t_end_ms = 1e-3 * ggml_time_us();

        // "unaccounted" = wall time not attributed to sampling or eval
        // NOTE(review): t_unacc_pc divides by t_total_ms, which could be ~0 right
        // after context creation - TODO confirm callers only print after real work
        const double t_total_ms = t_end_ms - data.t_start_ms;
        const double t_unacc_ms = t_total_ms - (t_sampling_ms + data.t_p_eval_ms + data.t_eval_ms);
        const double t_unacc_pc = 100.0 * t_unacc_ms / t_total_ms;

        LOG_INF("%s:        load time = %10.2f ms\n", __func__, data.t_load_ms);
        LOG_INF("%s: prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n",
                __func__, data.t_p_eval_ms, data.n_p_eval, data.t_p_eval_ms / data.n_p_eval, 1e3 / data.t_p_eval_ms * data.n_p_eval);
        LOG_INF("%s:        eval time = %10.2f ms / %5d runs   (%8.2f ms per token, %8.2f tokens per second)\n",
                __func__, data.t_eval_ms, data.n_eval, data.t_eval_ms / data.n_eval, 1e3 / data.t_eval_ms * data.n_eval);
        LOG_INF("%s:       total time = %10.2f ms / %5d tokens\n", __func__, (t_end_ms - data.t_start_ms), (data.n_p_eval + data.n_eval));
        LOG_INF("%s: unaccounted time = %10.2f ms / %5.1f %% (total - sampling - prompt eval - eval) / (total)\n", __func__, t_unacc_ms, t_unacc_pc);
        LOG_INF("%s:    graphs reused = %10d\n", __func__, data.n_reused);

        llama_memory_breakdown_print(ctx);
    }
}
|
|
||||||
|
|
||||||
// Access the underlying llama_sampler chain (ownership stays with gsmpl).
struct llama_sampler * common_sampler_get(const struct common_sampler * gsmpl) {
    return gsmpl->chain;
}
|
|
||||||
|
|
||||||
// Sample one token for output position idx.
//
// Fast path: apply the chain first, then verify the sampled token against the
// grammar; only on rejection re-apply the grammar over the full vocabulary and
// resample (grammar-based rejection sampling). If grammar_first is set, the
// grammar is applied before the chain instead (slower, but all candidates fit
// the grammar).
llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) {
    llama_synchronize(ctx);

    // start measuring sampling time after the llama_context synchronization in order to not measure any ongoing async operations
    const auto tm = gsmpl->tm();

    llama_token id = LLAMA_TOKEN_NULL;

    auto & grmr  = gsmpl->grmr;
    auto & chain = gsmpl->chain;
    auto & cur_p = gsmpl->cur_p; // initialized by set_logits

    // Check if a backend sampler has already sampled a token in which case we
    // return that token id directly.
    {
        id = llama_get_sampled_token_ith(ctx, idx);

        if (id != LLAMA_TOKEN_NULL) {
            LOG_DBG("%s: Backend sampler selected token: '%d'. Will not run any CPU samplers\n", __func__, id);

            GGML_ASSERT(!gsmpl->grmr && "using grammar in combination with backend sampling is not supported");

            // TODO: simplify
            // make cur_p reflect the single backend-sampled token (logit 0, prob 1)
            gsmpl->cur.resize(1);
            gsmpl->cur[0] = { id, 0.0f, 1.0f };
            cur_p = { gsmpl->cur.data(), gsmpl->cur.size(), 0, true };

            return id;
        }
    }

    gsmpl->set_logits(ctx, idx);

    if (grammar_first) {
        llama_sampler_apply(grmr, &cur_p);
    }

    llama_sampler_apply(chain, &cur_p);

    id = cur_p.data[cur_p.selected].id;

    if (grammar_first) {
        // grammar constraints were already applied above - no verification needed
        return id;
    }

    // check if it the sampled token fits the grammar (grammar-based rejection sampling)
    {
        llama_token_data       single_token_data       = { id, 1.0f, 0.0f };
        llama_token_data_array single_token_data_array = { &single_token_data, 1, -1, false };

        llama_sampler_apply(grmr, &single_token_data_array);

        // the grammar sampler rejects a token by forcing its logit to -INFINITY
        const bool is_valid = single_token_data_array.data[0].logit != -INFINITY;
        if (is_valid) {
            return id;
        }
    }

    // resampling:
    // if the token is not valid, sample again, but first apply the grammar sampler and then the sampling chain
    gsmpl->set_logits(ctx, idx);

    llama_sampler_apply(grmr,  &cur_p);
    llama_sampler_apply(chain, &cur_p);

    GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration");

    id = cur_p.data[cur_p.selected].id;

    return id;
}
|
|
||||||
|
|
||||||
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first) {
|
|
||||||
GGML_ASSERT(idxs.size() == draft.size() + 1 && "idxs.size() must be draft.size() + 1");
|
|
||||||
|
|
||||||
std::vector<llama_token> result;
|
|
||||||
result.reserve(idxs.size());
|
|
||||||
|
|
||||||
size_t i = 0;
|
|
||||||
for (; i < draft.size(); i++) {
|
|
||||||
const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);
|
|
||||||
|
|
||||||
common_sampler_accept(gsmpl, id, true);
|
|
||||||
|
|
||||||
result.push_back(id);
|
|
||||||
|
|
||||||
if (draft[i] != id) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (i == draft.size()) {
|
|
||||||
const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);
|
|
||||||
|
|
||||||
common_sampler_accept(gsmpl, id, true);
|
|
||||||
|
|
||||||
result.push_back(id);
|
|
||||||
}
|
|
||||||
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first) {
|
|
||||||
std::vector<int> idxs(draft.size() + 1);
|
|
||||||
for (size_t i = 0; i < idxs.size(); ++i) {
|
|
||||||
idxs[i] = i;
|
|
||||||
}
|
|
||||||
|
|
||||||
return common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft, grammar_first);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Return the RNG seed used by the sampler chain.
uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) {
    return llama_sampler_get_seed(gsmpl->chain);
}
|
|
||||||
|
|
||||||
// helpers
|
|
||||||
|
|
||||||
// Access the internal candidate-token array from the last sampling call.
// When do_sort is true, guarantees the candidates are sorted in descending
// order of probability on return; the .sorted flag reflects the final state.
llama_token_data_array * common_sampler_get_candidates(struct common_sampler * gsmpl, bool do_sort) {
    const auto tm = gsmpl->tm();

    auto * cand = &gsmpl->cur_p;

    if (!do_sort || cand->sorted) {
        return cand;
    }

    // remember the selected token before sorting
    const llama_token sel_id = cand->data[cand->selected].id;

    std::sort(cand->data, cand->data + cand->size, [](const llama_token_data & a, const llama_token_data & b) {
        return a.p > b.p;
    });

    // restore the selected token after sorting
    for (size_t i = 0; i < cand->size; ++i) {
        if (cand->data[i].id == sel_id) {
            cand->selected = i;
            break;
        }
    }

    cand->sorted = true;

    return cand;
}
|
|
||||||
|
|
||||||
// Return the most recently accepted token (rat(0) = newest ring-buffer entry).
llama_token common_sampler_last(const struct common_sampler * gsmpl) {
    return gsmpl->prev.rat(0);
}
|
|
||||||
|
|
||||||
std::string common_sampler_print(const struct common_sampler * gsmpl) {
|
|
||||||
std::string result = "logits ";
|
|
||||||
|
|
||||||
for (int i = 0; i < llama_sampler_chain_n(gsmpl->chain); i++) {
|
|
||||||
const auto * smpl = llama_sampler_chain_get(gsmpl->chain, i);
|
|
||||||
result += std::string("-> ");
|
|
||||||
result += std::string(llama_sampler_name(smpl)) + " ";
|
|
||||||
}
|
|
||||||
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Detokenize the last n accepted tokens into a string, oldest first.
std::string common_sampler_prev_str(common_sampler * gsmpl, llama_context * ctx_main, int n) {
    // never read more entries than the history actually holds
    n = std::min(n, (int) gsmpl->prev.size());

    if (n <= 0) {
        return "";
    }

    std::string out;
    out.reserve(8*n); // 8 is the average length of a token [citation needed], TODO: compute this from the vocab

    // rat(i) indexes from the newest entry, so walk i from n-1 down to 0
    // to emit tokens in chronological order
    for (int i = n - 1; i >= 0; i--) {
        const llama_token id = gsmpl->prev.rat(i);

        GGML_ASSERT(id != LLAMA_TOKEN_NULL && "null token in the sampling history - should not happen");

        out += common_token_to_piece(ctx_main, id);
    }

    return out;
}
|
|
||||||
|
|
||||||
// Map a sampler type to its single-character CLI shorthand; '?' for unknown.
char common_sampler_type_to_chr(enum common_sampler_type cnstr) {
    switch (cnstr) {
        case COMMON_SAMPLER_TYPE_DRY:         return 'd';
        case COMMON_SAMPLER_TYPE_TOP_K:       return 'k';
        case COMMON_SAMPLER_TYPE_TYPICAL_P:   return 'y';
        case COMMON_SAMPLER_TYPE_TOP_P:       return 'p';
        case COMMON_SAMPLER_TYPE_TOP_N_SIGMA: return 's';
        case COMMON_SAMPLER_TYPE_MIN_P:       return 'm';
        case COMMON_SAMPLER_TYPE_TEMPERATURE: return 't';
        case COMMON_SAMPLER_TYPE_XTC:         return 'x';
        case COMMON_SAMPLER_TYPE_INFILL:      return 'i';
        case COMMON_SAMPLER_TYPE_PENALTIES:   return 'e';
        default : return '?';
    }
}
|
|
||||||
|
|
||||||
// Map a sampler type to its canonical string name; empty string for unknown.
std::string common_sampler_type_to_str(enum common_sampler_type cnstr) {
    switch (cnstr) {
        case COMMON_SAMPLER_TYPE_DRY:         return "dry";
        case COMMON_SAMPLER_TYPE_TOP_K:       return "top_k";
        case COMMON_SAMPLER_TYPE_TYPICAL_P:   return "typ_p";
        case COMMON_SAMPLER_TYPE_TOP_P:       return "top_p";
        case COMMON_SAMPLER_TYPE_TOP_N_SIGMA: return "top_n_sigma";
        case COMMON_SAMPLER_TYPE_MIN_P:       return "min_p";
        case COMMON_SAMPLER_TYPE_TEMPERATURE: return "temperature";
        case COMMON_SAMPLER_TYPE_XTC:         return "xtc";
        case COMMON_SAMPLER_TYPE_INFILL:      return "infill";
        case COMMON_SAMPLER_TYPE_PENALTIES:   return "penalties";
        default : return "";
    }
}
|
|
||||||
|
|
||||||
std::vector<common_sampler_type> common_sampler_types_from_names(const std::vector<std::string> & names, bool allow_alt_names) {
|
|
||||||
std::unordered_map<std::string, common_sampler_type> sampler_canonical_name_map {
|
|
||||||
{ "dry", COMMON_SAMPLER_TYPE_DRY },
|
|
||||||
{ "top_k", COMMON_SAMPLER_TYPE_TOP_K },
|
|
||||||
{ "top_p", COMMON_SAMPLER_TYPE_TOP_P },
|
|
||||||
{ "top_n_sigma", COMMON_SAMPLER_TYPE_TOP_N_SIGMA },
|
|
||||||
{ "typ_p", COMMON_SAMPLER_TYPE_TYPICAL_P },
|
|
||||||
{ "min_p", COMMON_SAMPLER_TYPE_MIN_P },
|
|
||||||
{ "temperature", COMMON_SAMPLER_TYPE_TEMPERATURE },
|
|
||||||
{ "xtc", COMMON_SAMPLER_TYPE_XTC },
|
|
||||||
{ "infill", COMMON_SAMPLER_TYPE_INFILL },
|
|
||||||
{ "penalties", COMMON_SAMPLER_TYPE_PENALTIES },
|
|
||||||
};
|
|
||||||
|
|
||||||
// since samplers names are written multiple ways
|
|
||||||
// make it ready for both system names and input names
|
|
||||||
std::unordered_map<std::string, common_sampler_type> sampler_alt_name_map {
|
|
||||||
{ "top-k", COMMON_SAMPLER_TYPE_TOP_K },
|
|
||||||
{ "top-p", COMMON_SAMPLER_TYPE_TOP_P },
|
|
||||||
{ "top-n-sigma", COMMON_SAMPLER_TYPE_TOP_N_SIGMA },
|
|
||||||
{ "nucleus", COMMON_SAMPLER_TYPE_TOP_P },
|
|
||||||
{ "typical-p", COMMON_SAMPLER_TYPE_TYPICAL_P },
|
|
||||||
{ "typical", COMMON_SAMPLER_TYPE_TYPICAL_P },
|
|
||||||
{ "typ-p", COMMON_SAMPLER_TYPE_TYPICAL_P },
|
|
||||||
{ "typ", COMMON_SAMPLER_TYPE_TYPICAL_P },
|
|
||||||
{ "min-p", COMMON_SAMPLER_TYPE_MIN_P },
|
|
||||||
{ "temp", COMMON_SAMPLER_TYPE_TEMPERATURE },
|
|
||||||
};
|
|
||||||
|
|
||||||
std::vector<common_sampler_type> samplers;
|
|
||||||
samplers.reserve(names.size());
|
|
||||||
|
|
||||||
for (const auto & name : names) {
|
|
||||||
auto sampler = sampler_canonical_name_map.find(name);
|
|
||||||
if (sampler != sampler_canonical_name_map.end()) {
|
|
||||||
samplers.push_back(sampler->second);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (allow_alt_names) {
|
|
||||||
sampler = sampler_alt_name_map.find(name);
|
|
||||||
if (sampler != sampler_alt_name_map.end()) {
|
|
||||||
samplers.push_back(sampler->second);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
LOG_WRN("%s: unable to match sampler by name '%s'\n", __func__, name.c_str());
|
|
||||||
}
|
|
||||||
|
|
||||||
return samplers;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<common_sampler_type> common_sampler_types_from_chars(const std::string & chars) {
|
|
||||||
std::unordered_map<char, common_sampler_type> sampler_name_map = {
|
|
||||||
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_DRY), COMMON_SAMPLER_TYPE_DRY },
|
|
||||||
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_K), COMMON_SAMPLER_TYPE_TOP_K },
|
|
||||||
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TYPICAL_P), COMMON_SAMPLER_TYPE_TYPICAL_P },
|
|
||||||
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_P), COMMON_SAMPLER_TYPE_TOP_P },
|
|
||||||
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_N_SIGMA), COMMON_SAMPLER_TYPE_TOP_N_SIGMA },
|
|
||||||
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_MIN_P), COMMON_SAMPLER_TYPE_MIN_P },
|
|
||||||
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TEMPERATURE), COMMON_SAMPLER_TYPE_TEMPERATURE },
|
|
||||||
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_XTC), COMMON_SAMPLER_TYPE_XTC },
|
|
||||||
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_INFILL), COMMON_SAMPLER_TYPE_INFILL },
|
|
||||||
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_PENALTIES), COMMON_SAMPLER_TYPE_PENALTIES },
|
|
||||||
};
|
|
||||||
|
|
||||||
std::vector<common_sampler_type> samplers;
|
|
||||||
samplers.reserve(chars.size());
|
|
||||||
|
|
||||||
for (const auto & c : chars) {
|
|
||||||
const auto sampler = sampler_name_map.find(c);
|
|
||||||
if (sampler != sampler_name_map.end()) {
|
|
||||||
samplers.push_back(sampler->second);
|
|
||||||
} else {
|
|
||||||
LOG_WRN("%s: unable to match sampler by char '%c'\n", __func__, c);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return samplers;
|
|
||||||
}
|
|
||||||
@@ -1,119 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
#include "llama.h"
|
|
||||||
|
|
||||||
#include "common.h"
|
|
||||||
|
|
||||||
#include <string>
|
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
// common_sampler extends llama_sampler with additional functionality:
|
|
||||||
//
|
|
||||||
// - grammar support
|
|
||||||
// - custom sampler logic based on the parameters
|
|
||||||
// - history of the last accepted tokens
|
|
||||||
// - performance metrics
|
|
||||||
//
|
|
||||||
// This goal is to have a common implementation of the sampling logic shared across the examples.
|
|
||||||
// For example, depending on the temperature, the sampling chain can be very simple (greedy) or more
|
|
||||||
// complex (top-k, top-p, etc).
|
|
||||||
//
|
|
||||||
// Another example is related to the grammar. In general, the grammar constraints applied on the full
|
|
||||||
// vocabulary can be very taxing. To improve performance, the grammar can be applied only to the sampled
|
|
||||||
// token in order to verify if it fits the grammar. And only if the token doesn't fit the grammar, the
|
|
||||||
// grammar constraints are applied to the full vocabulary and the token is resampled.
|
|
||||||
//
|
|
||||||
// The common_sampler also maintains a container with the last accepted tokens. In the future, this can
|
|
||||||
// be moved into the core llama library.
|
|
||||||
//
|
|
||||||
// For convenience, the common_sampler also maintains a container with the current candidate tokens.
|
|
||||||
// This can be used to access the probabilities of the rest of the non-sampled tokens.
|
|
||||||
//
|
|
||||||
// TODO: measure grammar performance
|
|
||||||
//
|
|
||||||
|
|
||||||
struct common_sampler;
|
|
||||||
|
|
||||||
// llama_sampler API overloads
|
|
||||||
|
|
||||||
// note: can mutate params in some cases
|
|
||||||
struct common_sampler * common_sampler_init(const struct llama_model * model, struct common_params_sampling & params);
|
|
||||||
|
|
||||||
void common_sampler_free(struct common_sampler * gsmpl);
|
|
||||||
|
|
||||||
// if accept_grammar is true, the token is accepted both by the sampling chain and the grammar
|
|
||||||
void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar);
|
|
||||||
void common_sampler_reset (struct common_sampler * gsmpl);
|
|
||||||
struct common_sampler * common_sampler_clone (struct common_sampler * gsmpl);
|
|
||||||
|
|
||||||
// arguments can be nullptr to skip printing
|
|
||||||
void common_perf_print(const struct llama_context * ctx, const struct common_sampler * gsmpl);
|
|
||||||
|
|
||||||
// get the underlying llama_sampler_chain
|
|
||||||
struct llama_sampler * common_sampler_get(const struct common_sampler * gsmpl);
|
|
||||||
|
|
||||||
// extended sampling implementation:
|
|
||||||
//
|
|
||||||
// - set logits
|
|
||||||
// - apply the configured sampler chain
|
|
||||||
// - check if the token fits the grammar (if any)
|
|
||||||
// - if not: resample by first applying the grammar constraints and then sampling again (slower path)
|
|
||||||
//
|
|
||||||
// if grammar_first is true, the grammar is applied before the samplers (slower)
|
|
||||||
// useful in cases where all the resulting candidates (not just the sampled one) must fit the grammar
|
|
||||||
//
|
|
||||||
llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first = false);
|
|
||||||
|
|
||||||
// generalized version of common_sampler_sample
|
|
||||||
//
|
|
||||||
// will cross-reference the sampled tokens with a batch of draft tokens and accept those that match
|
|
||||||
// if the sampler disagrees at some point, we stop and return the accepted tokens up to now
|
|
||||||
//
|
|
||||||
// common_sampler_sample_n(gsmpl, ctx, { idx }, {});
|
|
||||||
//
|
|
||||||
// is equivalent to
|
|
||||||
//
|
|
||||||
// common_sampler_sample(gsmpl, ctx, idx);
|
|
||||||
// common_sampler_accept(gsmpl, token, true);
|
|
||||||
//
|
|
||||||
// requires: idxs.size() == draft.size() + 1
|
|
||||||
//
|
|
||||||
// returns at least 1 token, up to idxs.size()
|
|
||||||
//
|
|
||||||
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first = false);
|
|
||||||
|
|
||||||
// assume idxs == [ 0, 1, 2, ..., draft.size() ]
|
|
||||||
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first = false);
|
|
||||||
|
|
||||||
uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl);
|
|
||||||
|
|
||||||
// helpers
|
|
||||||
|
|
||||||
// access the internal list of current candidate tokens
|
|
||||||
// if do_sort == true, the candidates are guaranteed to be sorted afterwards (in descending order of probability)
|
|
||||||
// the .sorted flag of the result indicates whether the returned candidates are sorted
|
|
||||||
llama_token_data_array * common_sampler_get_candidates(struct common_sampler * gsmpl, bool do_sort);
|
|
||||||
|
|
||||||
// get the last accepted token
|
|
||||||
llama_token common_sampler_last(const struct common_sampler * gsmpl);
|
|
||||||
|
|
||||||
// print the sampler chain into a string
|
|
||||||
std::string common_sampler_print(const struct common_sampler * gsmpl);
|
|
||||||
|
|
||||||
// get a string representation of the last accepted tokens
|
|
||||||
std::string common_sampler_prev_str(common_sampler * gsmpl, llama_context * ctx, int n);
|
|
||||||
|
|
||||||
char common_sampler_type_to_chr(enum common_sampler_type cnstr);
|
|
||||||
std::string common_sampler_type_to_str(enum common_sampler_type cnstr);
|
|
||||||
|
|
||||||
std::vector<enum common_sampler_type> common_sampler_types_from_names(const std::vector<std::string> & names, bool allow_alt_names);
|
|
||||||
std::vector<enum common_sampler_type> common_sampler_types_from_chars(const std::string & chars);
|
|
||||||
|
|
||||||
llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab,
|
|
||||||
const char * grammar_kind, const char * grammar_data);
|
|
||||||
|
|
||||||
// Custom deleter so std::unique_ptr releases a common_sampler via
// common_sampler_free() (used by common_sampler_ptr below).
struct common_sampler_deleter {
    void operator()(common_sampler * s) { common_sampler_free(s); }
};
|
|
||||||
|
|
||||||
typedef std::unique_ptr<common_sampler, common_sampler_deleter> common_sampler_ptr;
|
|
||||||
@@ -1,361 +0,0 @@
|
|||||||
#include "speculative.h"
|
|
||||||
|
|
||||||
#include "ggml.h"
|
|
||||||
#include "llama.h"
|
|
||||||
#include "log.h"
|
|
||||||
#include "common.h"
|
|
||||||
#include "sampling.h"
|
|
||||||
|
|
||||||
#include <cstring>
|
|
||||||
#include <algorithm>
|
|
||||||
#include <map>
|
|
||||||
|
|
||||||
#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 128
|
|
||||||
#define SPEC_VOCAB_CHECK_START_TOKEN_ID 5
|
|
||||||
|
|
||||||
// State for speculative decoding with a draft model.
// NOTE: member order matters - common_speculative_init() initializes this
// struct with a positional aggregate initializer.
struct common_speculative {
    struct llama_context * ctx_tgt; // only used for retokenizing from ctx_dft
    struct llama_context * ctx_dft; // draft-model context
    struct common_sampler * smpl;   // sampler used to pick draft tokens

    llama_batch  batch;             // reusable batch for draft decoding
    llama_tokens prompt_dft;        // prompt as seen by the draft model
    bool vocab_dft_compatible = true; // whether retokenization is needed
    // target<->draft string replacements applied while retokenizing
    std::map<std::string, std::string> tgt_dft_replacements = {};
};
|
|
||||||
|
|
||||||
struct common_speculative * common_speculative_init(
|
|
||||||
struct llama_context * ctx_tgt,
|
|
||||||
struct llama_context * ctx_dft) {
|
|
||||||
auto * result = new common_speculative {
|
|
||||||
/* .ctx_tgt = */ ctx_tgt,
|
|
||||||
/* .ctx_dft = */ ctx_dft,
|
|
||||||
/* .smpl = */ nullptr,
|
|
||||||
/* .batch = */ llama_batch_init(llama_n_batch(ctx_dft), 0, 1),
|
|
||||||
/* .prompt_dft = */ {},
|
|
||||||
/* .vocab_dft_compatible = */ false,
|
|
||||||
};
|
|
||||||
|
|
||||||
// TODO: optimize or pass from outside?
|
|
||||||
#if 0
|
|
||||||
{
|
|
||||||
common_params_sampling params;
|
|
||||||
params.no_perf = false;
|
|
||||||
|
|
||||||
params.top_k = 40;
|
|
||||||
params.top_p = 0.9;
|
|
||||||
|
|
||||||
params.samplers = {
|
|
||||||
COMMON_SAMPLER_TYPE_TOP_K,
|
|
||||||
COMMON_SAMPLER_TYPE_TOP_P,
|
|
||||||
COMMON_SAMPLER_TYPE_INFILL,
|
|
||||||
};
|
|
||||||
|
|
||||||
result->smpl = common_sampler_init(llama_get_model(ctx_dft), params);
|
|
||||||
}
|
|
||||||
#else
|
|
||||||
{
|
|
||||||
common_params_sampling params;
|
|
||||||
params.no_perf = false;
|
|
||||||
|
|
||||||
params.top_k = 10;
|
|
||||||
|
|
||||||
params.samplers = {
|
|
||||||
COMMON_SAMPLER_TYPE_TOP_K,
|
|
||||||
};
|
|
||||||
|
|
||||||
result->smpl = common_sampler_init(llama_get_model(ctx_dft), params);
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
|
|
||||||
result->vocab_dft_compatible = common_speculative_are_compatible(ctx_tgt, ctx_dft);
|
|
||||||
LOG_DBG("vocab_dft_compatible = %d\n", result->vocab_dft_compatible);
|
|
||||||
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Destroy speculative-decoding state; nullptr is a no-op.
void common_speculative_free(struct common_speculative * spec) {
    if (!spec) {
        return;
    }

    // release owned resources before the wrapper itself
    common_sampler_free(spec->smpl);
    llama_batch_free(spec->batch);

    delete spec;
}
|
|
||||||
|
|
||||||
bool common_speculative_are_compatible(
|
|
||||||
const struct llama_context * ctx_tgt,
|
|
||||||
const struct llama_context * ctx_dft) {
|
|
||||||
const struct llama_model * model_tgt = llama_get_model(ctx_tgt);
|
|
||||||
const struct llama_model * model_dft = llama_get_model(ctx_dft);
|
|
||||||
|
|
||||||
const struct llama_vocab * vocab_tgt = llama_model_get_vocab(model_tgt);
|
|
||||||
const struct llama_vocab * vocab_dft = llama_model_get_vocab(model_dft);
|
|
||||||
|
|
||||||
const bool vocab_type_tgt = llama_vocab_type(vocab_tgt);
|
|
||||||
LOG_DBG("%s: vocab_type tgt: %d\n", __func__, vocab_type_tgt);
|
|
||||||
|
|
||||||
const bool vocab_type_dft = llama_vocab_type(vocab_dft);
|
|
||||||
LOG_DBG("%s: vocab_type dft: %d\n", __func__, vocab_type_dft);
|
|
||||||
|
|
||||||
if (vocab_type_tgt != vocab_type_dft) {
|
|
||||||
LOG_DBG("%s: draft model vocab type must match target model to use speculation but ", __func__);
|
|
||||||
LOG_DBG("vocab_type_dft = %d while vocab_type_tgt = %d\n", vocab_type_dft, vocab_type_tgt);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (
|
|
||||||
llama_vocab_get_add_bos(vocab_tgt) != llama_vocab_get_add_bos(vocab_dft) ||
|
|
||||||
llama_vocab_get_add_eos(vocab_tgt) != llama_vocab_get_add_eos(vocab_dft) ||
|
|
||||||
llama_vocab_bos(vocab_tgt) != llama_vocab_bos(vocab_dft) ||
|
|
||||||
llama_vocab_eos(vocab_tgt) != llama_vocab_eos(vocab_dft)
|
|
||||||
) {
|
|
||||||
LOG_DBG("%s: draft model special tokens must match target model to use speculation\n", __func__);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
{
|
|
||||||
const int n_vocab_tgt = llama_vocab_n_tokens(vocab_tgt);
|
|
||||||
const int n_vocab_dft = llama_vocab_n_tokens(vocab_dft);
|
|
||||||
const int vocab_diff = n_vocab_tgt > n_vocab_dft
|
|
||||||
? n_vocab_tgt - n_vocab_dft
|
|
||||||
: n_vocab_dft - n_vocab_tgt;
|
|
||||||
|
|
||||||
if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) {
|
|
||||||
LOG_DBG("%s: draft model vocab must closely match target model to use speculation but ", __func__);
|
|
||||||
LOG_DBG("target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n",
|
|
||||||
n_vocab_tgt, llama_vocab_n_tokens(vocab_dft), vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
for (int i = SPEC_VOCAB_CHECK_START_TOKEN_ID; i < std::min(n_vocab_tgt, n_vocab_dft); ++i) {
|
|
||||||
const char * token_text_tgt = llama_vocab_get_text(vocab_tgt, i);
|
|
||||||
const char * token_text_dft = llama_vocab_get_text(vocab_dft, i);
|
|
||||||
if (std::strcmp(token_text_tgt, token_text_dft) != 0) {
|
|
||||||
LOG_DBG("%s: draft model vocab must match target model to use speculation but ", __func__);
|
|
||||||
LOG_DBG("token %d content differs - target '%s', draft '%s'\n", i,
|
|
||||||
common_token_to_piece(ctx_tgt, i).c_str(),
|
|
||||||
common_token_to_piece(ctx_dft, i).c_str());
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Register a target->draft string replacement, applied during retokenization
// when the two vocabularies are not directly compatible.
void common_speculative_add_replacement_tgt_dft(
        struct common_speculative * spec,
        const char *source, const char *dest) {
    spec->tgt_dft_replacements[source] = dest;
}
|
|
||||||
|
|
||||||
static std::string replace_to_dft(
|
|
||||||
struct common_speculative * spec,
|
|
||||||
const std::string& input) {
|
|
||||||
std::string result = input;
|
|
||||||
for (const auto & pair : spec->tgt_dft_replacements) {
|
|
||||||
size_t pos = result.find(pair.first);
|
|
||||||
while (pos != std::string::npos) {
|
|
||||||
result.replace(pos, pair.first.length(), pair.second);
|
|
||||||
pos = result.find(pair.first, pos + pair.second.length());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
static std::string replace_to_tgt(
|
|
||||||
struct common_speculative * spec,
|
|
||||||
const std::string& input) {
|
|
||||||
std::string result = input;
|
|
||||||
for (const auto& pair : spec->tgt_dft_replacements) {
|
|
||||||
size_t pos = result.find(pair.second);
|
|
||||||
while (pos != std::string::npos) {
|
|
||||||
result.replace(pos, pair.second.length(), pair.first);
|
|
||||||
pos = result.find(pair.second, pos + pair.first.length());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
llama_tokens common_speculative_gen_draft(
|
|
||||||
struct common_speculative * spec,
|
|
||||||
struct common_speculative_params params,
|
|
||||||
const llama_tokens & prompt_tgt_main_model, // specified in target model vocab
|
|
||||||
llama_token id_last) {
|
|
||||||
auto & batch = spec->batch;
|
|
||||||
auto & ctx_tgt = spec->ctx_tgt;
|
|
||||||
auto & ctx_dft = spec->ctx_dft;
|
|
||||||
auto & smpl = spec->smpl;
|
|
||||||
auto & prompt_dft = spec->prompt_dft;
|
|
||||||
|
|
||||||
auto * mem_dft = llama_get_memory(ctx_dft);
|
|
||||||
|
|
||||||
int reuse_i = 0;
|
|
||||||
int reuse_n = 0;
|
|
||||||
|
|
||||||
const int n_ctx = llama_n_ctx(ctx_dft) - params.n_draft;
|
|
||||||
|
|
||||||
llama_tokens prompt_tgt_draft_model;
|
|
||||||
if (!spec->vocab_dft_compatible) {
|
|
||||||
std::string text;
|
|
||||||
text = common_detokenize(ctx_tgt, prompt_tgt_main_model, true);
|
|
||||||
text = replace_to_dft(spec, text);
|
|
||||||
LOG_DBG("%s: main->draft detokenized string: '%s'\n", __func__, text.c_str());
|
|
||||||
prompt_tgt_draft_model = common_tokenize(ctx_dft, text, false, true);
|
|
||||||
|
|
||||||
// convert id_last to draft vocab. llama_detokenize is called directly to avoid an allocation
|
|
||||||
const auto * model_tgt = llama_get_model(ctx_tgt);
|
|
||||||
const auto * vocab_tgt = llama_model_get_vocab(model_tgt);
|
|
||||||
|
|
||||||
int32_t n_chars = llama_detokenize(vocab_tgt, &id_last, 1, nullptr, 0, false, false);
|
|
||||||
GGML_ASSERT(n_chars < 0 && "failed to detokenize id_last");
|
|
||||||
text.resize(-n_chars);
|
|
||||||
llama_detokenize(vocab_tgt, &id_last, 1, text.data(), text.size(), false, false);
|
|
||||||
text = replace_to_dft(spec, text);
|
|
||||||
|
|
||||||
LOG_DBG("main->draft detokenized id_last(%d): '%s'\n", id_last, text.c_str());
|
|
||||||
id_last = common_tokenize(ctx_dft, text, false, true)[0];
|
|
||||||
}
|
|
||||||
// prompt_tgt's tokens will always be compatible with ctx_dft
|
|
||||||
const llama_tokens &prompt_tgt =
|
|
||||||
spec->vocab_dft_compatible ? prompt_tgt_main_model : prompt_tgt_draft_model;
|
|
||||||
|
|
||||||
const int i_start = std::max<int>(0, (int) prompt_tgt.size() - n_ctx);
|
|
||||||
|
|
||||||
// reuse as much as possible from the old draft context
|
|
||||||
// ideally, the draft context should be as big as the target context and we will always reuse the entire prompt
|
|
||||||
for (int i = 0; i < (int) prompt_dft.size(); ++i) {
|
|
||||||
int cur = 0;
|
|
||||||
while (i_start + cur < (int) prompt_tgt.size() &&
|
|
||||||
i + cur < (int) prompt_dft.size() &&
|
|
||||||
prompt_tgt[i_start + cur] == prompt_dft[i + cur]) {
|
|
||||||
cur++;
|
|
||||||
}
|
|
||||||
|
|
||||||
if ((cur >= params.n_reuse || n_ctx >= (int) prompt_tgt.size()) && cur > reuse_n) {
|
|
||||||
reuse_i = i;
|
|
||||||
reuse_n = cur;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
LOG_DBG("%s: reuse_i = %d, reuse_n = %d, prompt = %d\n", __func__, reuse_i, reuse_n, (int) prompt_dft.size());
|
|
||||||
|
|
||||||
llama_tokens result;
|
|
||||||
result.reserve(params.n_draft);
|
|
||||||
|
|
||||||
if (reuse_n == 0) {
|
|
||||||
llama_memory_clear(mem_dft, false);
|
|
||||||
prompt_dft.clear();
|
|
||||||
} else {
|
|
||||||
// this happens when a previous draft has been discarded (for example, due to being too small), but the
|
|
||||||
// target model agreed with it. in this case, we simply pass back the previous results to save compute
|
|
||||||
if (reuse_i + reuse_n < (int) prompt_dft.size() && prompt_dft[reuse_i + reuse_n] == id_last) {
|
|
||||||
for (int i = reuse_i + reuse_n + 1; i < (int) prompt_dft.size(); ++i) {
|
|
||||||
result.push_back(prompt_dft[i]);
|
|
||||||
|
|
||||||
if (params.n_draft <= (int) result.size()) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (reuse_i > 0) {
|
|
||||||
llama_memory_seq_rm (mem_dft, 0, 0, reuse_i);
|
|
||||||
llama_memory_seq_add(mem_dft, 0, reuse_i, -1, -reuse_i);
|
|
||||||
|
|
||||||
prompt_dft.erase(prompt_dft.begin(), prompt_dft.begin() + reuse_i);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (reuse_n < (int) prompt_dft.size()) {
|
|
||||||
llama_memory_seq_rm (mem_dft, 0, reuse_n, -1);
|
|
||||||
prompt_dft.erase(prompt_dft.begin() + reuse_n, prompt_dft.end());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// prepare a batch to evaluate any new tokens in the prompt
|
|
||||||
common_batch_clear(batch);
|
|
||||||
|
|
||||||
for (size_t i = i_start + reuse_n; i < prompt_tgt.size(); ++i) {
|
|
||||||
//LOG_DBG("i = %d, i_start = %d, reuse_n = %d, i - i_start = %d, id = %6d\n", i, i_start, reuse_n, i - i_start, prompt_tgt[i]);
|
|
||||||
common_batch_add(batch, prompt_tgt[i], i - i_start, { 0 }, false);
|
|
||||||
|
|
||||||
prompt_dft.push_back(prompt_tgt[i]);
|
|
||||||
}
|
|
||||||
|
|
||||||
// we should rarely end-up here during normal decoding
|
|
||||||
if (batch.n_tokens > 0) {
|
|
||||||
//LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str());
|
|
||||||
|
|
||||||
llama_decode(ctx_dft, batch);
|
|
||||||
}
|
|
||||||
|
|
||||||
const llama_pos n_past = prompt_dft.size();
|
|
||||||
|
|
||||||
LOG_DBG("%s: n_past = %d\n", __func__, n_past);
|
|
||||||
|
|
||||||
common_batch_clear(batch);
|
|
||||||
common_batch_add (batch, id_last, n_past, { 0 }, true);
|
|
||||||
|
|
||||||
prompt_dft.push_back(id_last);
|
|
||||||
|
|
||||||
LOG_DBG("%s: draft prompt: %s\n", __func__, string_from(ctx_dft, prompt_dft).c_str());
|
|
||||||
|
|
||||||
llama_decode(ctx_dft, batch);
|
|
||||||
|
|
||||||
common_sampler_reset(smpl);
|
|
||||||
|
|
||||||
// sample n_draft tokens from the draft model
|
|
||||||
for (int i = 0; i < params.n_draft; ++i) {
|
|
||||||
common_batch_clear(batch);
|
|
||||||
|
|
||||||
common_sampler_sample(smpl, ctx_dft, 0, true);
|
|
||||||
|
|
||||||
const auto * cur_p = common_sampler_get_candidates(smpl, true);
|
|
||||||
|
|
||||||
for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) {
|
|
||||||
LOG_DBG(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
|
|
||||||
k, i, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx_dft, cur_p->data[k].id).c_str());
|
|
||||||
}
|
|
||||||
|
|
||||||
// add drafted token for each sequence
|
|
||||||
const llama_token id = cur_p->data[0].id;
|
|
||||||
|
|
||||||
common_sampler_accept(smpl, id, true);
|
|
||||||
|
|
||||||
result.push_back(id);
|
|
||||||
|
|
||||||
if (params.n_draft <= (int) result.size()) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
// only collect very high-confidence draft tokens
|
|
||||||
if (cur_p->data[0].p < params.p_min) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
common_batch_add(batch, id, n_past + i + 1, { 0 }, true);
|
|
||||||
|
|
||||||
// evaluate the drafted tokens on the draft model
|
|
||||||
llama_decode(ctx_dft, batch);
|
|
||||||
|
|
||||||
prompt_dft.push_back(id);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!spec->vocab_dft_compatible) {
|
|
||||||
std::string detokenized = common_detokenize(ctx_dft, result, true);
|
|
||||||
detokenized = replace_to_tgt(spec, detokenized);
|
|
||||||
LOG_DBG("draft->main detokenized string: '%s'\n", detokenized.c_str());
|
|
||||||
result = common_tokenize(ctx_tgt, detokenized, false, true);
|
|
||||||
if (result.size() > (size_t)params.n_draft) {
|
|
||||||
result.resize(params.n_draft);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
@@ -1,35 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
#include "llama.h"
|
|
||||||
#include "common.h"
|
|
||||||
|
|
||||||
struct common_speculative;
|
|
||||||
|
|
||||||
struct common_speculative_params {
|
|
||||||
int n_draft = 16; // max drafted tokens
|
|
||||||
int n_reuse = 256;
|
|
||||||
|
|
||||||
float p_min = 0.75f; // min probability required to accept a token in the draft
|
|
||||||
};
|
|
||||||
|
|
||||||
struct common_speculative * common_speculative_init(
|
|
||||||
struct llama_context * ctx_tgt,
|
|
||||||
struct llama_context * ctx_dft
|
|
||||||
);
|
|
||||||
|
|
||||||
void common_speculative_free(struct common_speculative * spec);
|
|
||||||
|
|
||||||
bool common_speculative_are_compatible(
|
|
||||||
const struct llama_context * ctx_tgt,
|
|
||||||
const struct llama_context * ctx_dft);
|
|
||||||
|
|
||||||
void common_speculative_add_replacement_tgt_dft(
|
|
||||||
struct common_speculative * spec,
|
|
||||||
const char *source, const char *dest);
|
|
||||||
|
|
||||||
// sample up to n_draft tokens and add them to the batch using the draft model
|
|
||||||
llama_tokens common_speculative_gen_draft(
|
|
||||||
struct common_speculative * spec,
|
|
||||||
struct common_speculative_params params,
|
|
||||||
const llama_tokens & prompt,
|
|
||||||
llama_token id_last);
|
|
||||||
@@ -1,64 +0,0 @@
|
|||||||
#include "unicode.h"
|
|
||||||
|
|
||||||
// implementation adopted from src/unicode.cpp
|
|
||||||
|
|
||||||
size_t utf8_sequence_length(unsigned char first_byte) {
|
|
||||||
const size_t lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
|
|
||||||
uint8_t highbits = static_cast<uint8_t>(first_byte) >> 4;
|
|
||||||
return lookup[highbits];
|
|
||||||
}
|
|
||||||
|
|
||||||
utf8_parse_result parse_utf8_codepoint(std::string_view input, size_t offset) {
|
|
||||||
if (offset >= input.size()) {
|
|
||||||
return utf8_parse_result(utf8_parse_result::INCOMPLETE);
|
|
||||||
}
|
|
||||||
|
|
||||||
// ASCII fast path
|
|
||||||
if (!(input[offset] & 0x80)) {
|
|
||||||
return utf8_parse_result(utf8_parse_result::SUCCESS, input[offset], 1);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Invalid: continuation byte as first byte
|
|
||||||
if (!(input[offset] & 0x40)) {
|
|
||||||
return utf8_parse_result(utf8_parse_result::INVALID);
|
|
||||||
}
|
|
||||||
|
|
||||||
// 2-byte sequence
|
|
||||||
if (!(input[offset] & 0x20)) {
|
|
||||||
if (offset + 1 >= input.size()) {
|
|
||||||
return utf8_parse_result(utf8_parse_result::INCOMPLETE);
|
|
||||||
}
|
|
||||||
if ((input[offset + 1] & 0xc0) != 0x80) {
|
|
||||||
return utf8_parse_result(utf8_parse_result::INVALID);
|
|
||||||
}
|
|
||||||
auto result = ((input[offset] & 0x1f) << 6) | (input[offset + 1] & 0x3f);
|
|
||||||
return utf8_parse_result(utf8_parse_result::SUCCESS, result, 2);
|
|
||||||
}
|
|
||||||
|
|
||||||
// 3-byte sequence
|
|
||||||
if (!(input[offset] & 0x10)) {
|
|
||||||
if (offset + 2 >= input.size()) {
|
|
||||||
return utf8_parse_result(utf8_parse_result::INCOMPLETE);
|
|
||||||
}
|
|
||||||
if ((input[offset + 1] & 0xc0) != 0x80 || (input[offset + 2] & 0xc0) != 0x80) {
|
|
||||||
return utf8_parse_result(utf8_parse_result::INVALID);
|
|
||||||
}
|
|
||||||
auto result = ((input[offset] & 0x0f) << 12) | ((input[offset + 1] & 0x3f) << 6) | (input[offset + 2] & 0x3f);
|
|
||||||
return utf8_parse_result(utf8_parse_result::SUCCESS, result, 3);
|
|
||||||
}
|
|
||||||
|
|
||||||
// 4-byte sequence
|
|
||||||
if (!(input[offset] & 0x08)) {
|
|
||||||
if (offset + 3 >= input.size()) {
|
|
||||||
return utf8_parse_result(utf8_parse_result::INCOMPLETE);
|
|
||||||
}
|
|
||||||
if ((input[offset + 1] & 0xc0) != 0x80 || (input[offset + 2] & 0xc0) != 0x80 || (input[offset + 3] & 0xc0) != 0x80) {
|
|
||||||
return utf8_parse_result(utf8_parse_result::INVALID);
|
|
||||||
}
|
|
||||||
auto result = ((input[offset] & 0x07) << 18) | ((input[offset + 1] & 0x3f) << 12) | ((input[offset + 2] & 0x3f) << 6) | (input[offset + 3] & 0x3f);
|
|
||||||
return utf8_parse_result(utf8_parse_result::SUCCESS, result, 4);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Invalid first byte
|
|
||||||
return utf8_parse_result(utf8_parse_result::INVALID);
|
|
||||||
}
|
|
||||||
@@ -1,22 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
#include <cstdint>
|
|
||||||
#include <string_view>
|
|
||||||
|
|
||||||
// UTF-8 parsing utilities for streaming-aware unicode support
|
|
||||||
|
|
||||||
struct utf8_parse_result {
|
|
||||||
uint32_t codepoint; // Decoded codepoint (only valid if status == SUCCESS)
|
|
||||||
size_t bytes_consumed; // How many bytes this codepoint uses (1-4)
|
|
||||||
enum status { SUCCESS, INCOMPLETE, INVALID } status;
|
|
||||||
|
|
||||||
utf8_parse_result(enum status s, uint32_t cp = 0, size_t bytes = 0)
|
|
||||||
: codepoint(cp), bytes_consumed(bytes), status(s) {}
|
|
||||||
};
|
|
||||||
|
|
||||||
// Determine the expected length of a UTF-8 sequence from its first byte
|
|
||||||
// Returns 0 for invalid first bytes
|
|
||||||
size_t utf8_sequence_length(unsigned char first_byte);
|
|
||||||
|
|
||||||
// Parse a single UTF-8 codepoint from input
|
|
||||||
utf8_parse_result parse_utf8_codepoint(std::string_view input, size_t offset);
|
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -1,491 +0,0 @@
|
|||||||
cmake_minimum_required(VERSION 3.14) # for add_link_options and implicit target directories.
|
|
||||||
project("ggml" C CXX ASM)
|
|
||||||
|
|
||||||
### GGML Version
|
|
||||||
set(GGML_VERSION_MAJOR 0)
|
|
||||||
set(GGML_VERSION_MINOR 9)
|
|
||||||
set(GGML_VERSION_PATCH 5)
|
|
||||||
set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}")
|
|
||||||
|
|
||||||
find_program(GIT_EXE NAMES git git.exe NO_CMAKE_FIND_ROOT_PATH)
|
|
||||||
if(GIT_EXE)
|
|
||||||
# Get current git commit hash
|
|
||||||
execute_process(COMMAND ${GIT_EXE} rev-parse --short HEAD
|
|
||||||
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
|
|
||||||
OUTPUT_VARIABLE GGML_BUILD_COMMIT
|
|
||||||
OUTPUT_STRIP_TRAILING_WHITESPACE
|
|
||||||
ERROR_QUIET
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check if the working directory is dirty (i.e., has uncommitted changes)
|
|
||||||
execute_process(COMMAND ${GIT_EXE} diff-index --quiet HEAD -- .
|
|
||||||
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
|
|
||||||
RESULT_VARIABLE GGML_GIT_DIRTY
|
|
||||||
ERROR_QUIET
|
|
||||||
)
|
|
||||||
endif()
|
|
||||||
|
|
||||||
set(GGML_VERSION "${GGML_VERSION_BASE}")
|
|
||||||
|
|
||||||
if(NOT GGML_BUILD_COMMIT)
|
|
||||||
set(GGML_BUILD_COMMIT "unknown")
|
|
||||||
endif()
|
|
||||||
|
|
||||||
# Build the commit string with optional dirty flag
|
|
||||||
if(DEFINED GGML_GIT_DIRTY AND GGML_GIT_DIRTY EQUAL 1)
|
|
||||||
set(GGML_BUILD_COMMIT "${GGML_BUILD_COMMIT}-dirty")
|
|
||||||
endif()
|
|
||||||
|
|
||||||
include(CheckIncludeFileCXX)
|
|
||||||
|
|
||||||
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
|
|
||||||
|
|
||||||
if (NOT XCODE AND NOT MSVC AND NOT CMAKE_BUILD_TYPE)
|
|
||||||
set(CMAKE_BUILD_TYPE Release CACHE STRING "Build type" FORCE)
|
|
||||||
set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release" "MinSizeRel" "RelWithDebInfo")
|
|
||||||
endif()
|
|
||||||
|
|
||||||
if (CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR)
|
|
||||||
set(GGML_STANDALONE ON)
|
|
||||||
|
|
||||||
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
|
|
||||||
|
|
||||||
# configure project version
|
|
||||||
# TODO
|
|
||||||
else()
|
|
||||||
set(GGML_STANDALONE OFF)
|
|
||||||
|
|
||||||
if (NOT CMAKE_RUNTIME_OUTPUT_DIRECTORY)
|
|
||||||
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
|
|
||||||
endif()
|
|
||||||
endif()
|
|
||||||
|
|
||||||
if (EMSCRIPTEN)
|
|
||||||
set(BUILD_SHARED_LIBS_DEFAULT OFF)
|
|
||||||
|
|
||||||
option(GGML_WASM_SINGLE_FILE "ggml: embed WASM inside the generated ggml.js" ON)
|
|
||||||
else()
|
|
||||||
if (MINGW)
|
|
||||||
set(BUILD_SHARED_LIBS_DEFAULT OFF)
|
|
||||||
else()
|
|
||||||
set(BUILD_SHARED_LIBS_DEFAULT ON)
|
|
||||||
endif()
|
|
||||||
endif()
|
|
||||||
|
|
||||||
# remove the lib prefix on win32 mingw
|
|
||||||
if (WIN32)
|
|
||||||
set(CMAKE_STATIC_LIBRARY_PREFIX "")
|
|
||||||
set(CMAKE_SHARED_LIBRARY_PREFIX "")
|
|
||||||
set(CMAKE_SHARED_MODULE_PREFIX "")
|
|
||||||
endif()
|
|
||||||
|
|
||||||
option(BUILD_SHARED_LIBS "ggml: build shared libraries" ${BUILD_SHARED_LIBS_DEFAULT})
|
|
||||||
option(GGML_BACKEND_DL "ggml: build backends as dynamic libraries (requires BUILD_SHARED_LIBS)" OFF)
|
|
||||||
set(GGML_BACKEND_DIR "" CACHE PATH "ggml: directory to load dynamic backends from (requires GGML_BACKEND_DL")
|
|
||||||
|
|
||||||
#
|
|
||||||
# option list
|
|
||||||
#
|
|
||||||
|
|
||||||
# TODO: mark all options as advanced when not GGML_STANDALONE
|
|
||||||
|
|
||||||
if (APPLE)
|
|
||||||
set(GGML_METAL_DEFAULT ON)
|
|
||||||
set(GGML_BLAS_DEFAULT ON)
|
|
||||||
set(GGML_BLAS_VENDOR_DEFAULT "Apple")
|
|
||||||
else()
|
|
||||||
set(GGML_METAL_DEFAULT OFF)
|
|
||||||
set(GGML_BLAS_DEFAULT OFF)
|
|
||||||
set(GGML_BLAS_VENDOR_DEFAULT "Generic")
|
|
||||||
endif()
|
|
||||||
|
|
||||||
if (CMAKE_CROSSCOMPILING OR DEFINED ENV{SOURCE_DATE_EPOCH})
|
|
||||||
message(STATUS "Setting GGML_NATIVE_DEFAULT to OFF")
|
|
||||||
set(GGML_NATIVE_DEFAULT OFF)
|
|
||||||
else()
|
|
||||||
set(GGML_NATIVE_DEFAULT ON)
|
|
||||||
endif()
|
|
||||||
|
|
||||||
# defaults
|
|
||||||
if (NOT GGML_LLAMAFILE_DEFAULT)
|
|
||||||
set(GGML_LLAMAFILE_DEFAULT OFF)
|
|
||||||
endif()
|
|
||||||
|
|
||||||
if (NOT GGML_CUDA_GRAPHS_DEFAULT)
|
|
||||||
set(GGML_CUDA_GRAPHS_DEFAULT OFF)
|
|
||||||
endif()
|
|
||||||
|
|
||||||
# general
|
|
||||||
option(GGML_STATIC "ggml: static link libraries" OFF)
|
|
||||||
option(GGML_NATIVE "ggml: optimize the build for the current system" ${GGML_NATIVE_DEFAULT})
|
|
||||||
option(GGML_LTO "ggml: enable link time optimization" OFF)
|
|
||||||
option(GGML_CCACHE "ggml: use ccache if available" ON)
|
|
||||||
|
|
||||||
# debug
|
|
||||||
option(GGML_ALL_WARNINGS "ggml: enable all compiler warnings" ON)
|
|
||||||
option(GGML_ALL_WARNINGS_3RD_PARTY "ggml: enable all compiler warnings in 3rd party libs" OFF)
|
|
||||||
option(GGML_GPROF "ggml: enable gprof" OFF)
|
|
||||||
|
|
||||||
# build
|
|
||||||
option(GGML_FATAL_WARNINGS "ggml: enable -Werror flag" OFF)
|
|
||||||
|
|
||||||
# sanitizers
|
|
||||||
option(GGML_SANITIZE_THREAD "ggml: enable thread sanitizer" OFF)
|
|
||||||
option(GGML_SANITIZE_ADDRESS "ggml: enable address sanitizer" OFF)
|
|
||||||
option(GGML_SANITIZE_UNDEFINED "ggml: enable undefined sanitizer" OFF)
|
|
||||||
|
|
||||||
# instruction set specific
|
|
||||||
if (GGML_NATIVE OR NOT GGML_NATIVE_DEFAULT)
|
|
||||||
set(INS_ENB OFF)
|
|
||||||
else()
|
|
||||||
set(INS_ENB ON)
|
|
||||||
endif()
|
|
||||||
|
|
||||||
message(DEBUG "GGML_NATIVE : ${GGML_NATIVE}")
|
|
||||||
message(DEBUG "GGML_NATIVE_DEFAULT : ${GGML_NATIVE_DEFAULT}")
|
|
||||||
message(DEBUG "INS_ENB : ${INS_ENB}")
|
|
||||||
|
|
||||||
option(GGML_CPU_HBM "ggml: use memkind for CPU HBM" OFF)
|
|
||||||
option(GGML_CPU_REPACK "ggml: use runtime weight conversion of Q4_0 to Q4_X_X" ON)
|
|
||||||
option(GGML_CPU_KLEIDIAI "ggml: use KleidiAI optimized kernels if applicable" OFF)
|
|
||||||
option(GGML_SSE42 "ggml: enable SSE 4.2" ${INS_ENB})
|
|
||||||
option(GGML_AVX "ggml: enable AVX" ${INS_ENB})
|
|
||||||
option(GGML_AVX_VNNI "ggml: enable AVX-VNNI" OFF)
|
|
||||||
option(GGML_AVX2 "ggml: enable AVX2" ${INS_ENB})
|
|
||||||
option(GGML_BMI2 "ggml: enable BMI2" ${INS_ENB})
|
|
||||||
option(GGML_AVX512 "ggml: enable AVX512F" OFF)
|
|
||||||
option(GGML_AVX512_VBMI "ggml: enable AVX512-VBMI" OFF)
|
|
||||||
option(GGML_AVX512_VNNI "ggml: enable AVX512-VNNI" OFF)
|
|
||||||
option(GGML_AVX512_BF16 "ggml: enable AVX512-BF16" OFF)
|
|
||||||
if (NOT MSVC)
|
|
||||||
# in MSVC F16C and FMA is implied with AVX2/AVX512
|
|
||||||
option(GGML_FMA "ggml: enable FMA" ${INS_ENB})
|
|
||||||
option(GGML_F16C "ggml: enable F16C" ${INS_ENB})
|
|
||||||
# MSVC does not seem to support AMX
|
|
||||||
option(GGML_AMX_TILE "ggml: enable AMX-TILE" OFF)
|
|
||||||
option(GGML_AMX_INT8 "ggml: enable AMX-INT8" OFF)
|
|
||||||
option(GGML_AMX_BF16 "ggml: enable AMX-BF16" OFF)
|
|
||||||
endif()
|
|
||||||
option(GGML_LASX "ggml: enable lasx" ON)
|
|
||||||
option(GGML_LSX "ggml: enable lsx" ON)
|
|
||||||
option(GGML_RVV "ggml: enable rvv" ON)
|
|
||||||
option(GGML_RV_ZFH "ggml: enable riscv zfh" ON)
|
|
||||||
option(GGML_RV_ZVFH "ggml: enable riscv zvfh" ON)
|
|
||||||
option(GGML_RV_ZICBOP "ggml: enable riscv zicbop" ON)
|
|
||||||
option(GGML_RV_ZIHINTPAUSE "ggml: enable riscv zihintpause " ON)
|
|
||||||
option(GGML_XTHEADVECTOR "ggml: enable xtheadvector" OFF)
|
|
||||||
option(GGML_VXE "ggml: enable vxe" ${GGML_NATIVE})
|
|
||||||
|
|
||||||
option(GGML_CPU_ALL_VARIANTS "ggml: build all variants of the CPU backend (requires GGML_BACKEND_DL)" OFF)
|
|
||||||
set(GGML_CPU_ARM_ARCH "" CACHE STRING "ggml: CPU architecture for ARM")
|
|
||||||
set(GGML_CPU_POWERPC_CPUTYPE "" CACHE STRING "ggml: CPU type for PowerPC")
|
|
||||||
|
|
||||||
# ggml core
|
|
||||||
set(GGML_SCHED_MAX_COPIES "4" CACHE STRING "ggml: max input copies for pipeline parallelism")
|
|
||||||
option(GGML_CPU "ggml: enable CPU backend" ON)
|
|
||||||
option(GGML_SCHED_NO_REALLOC "ggml: disallow reallocations in ggml-alloc (for debugging)" OFF)
|
|
||||||
|
|
||||||
# 3rd party libs / backends
|
|
||||||
option(GGML_ACCELERATE "ggml: enable Accelerate framework" ON)
|
|
||||||
option(GGML_BLAS "ggml: use BLAS" ${GGML_BLAS_DEFAULT})
|
|
||||||
set(GGML_BLAS_VENDOR ${GGML_BLAS_VENDOR_DEFAULT} CACHE STRING
|
|
||||||
"ggml: BLAS library vendor")
|
|
||||||
option(GGML_LLAMAFILE "ggml: use LLAMAFILE" ${GGML_LLAMAFILE_DEFAULT})
|
|
||||||
|
|
||||||
option(GGML_CUDA "ggml: use CUDA" OFF)
|
|
||||||
option(GGML_MUSA "ggml: use MUSA" OFF)
|
|
||||||
option(GGML_CUDA_FORCE_MMQ "ggml: use mmq kernels instead of cuBLAS" OFF)
|
|
||||||
option(GGML_CUDA_FORCE_CUBLAS "ggml: always use cuBLAS instead of mmq kernels" OFF)
|
|
||||||
set (GGML_CUDA_PEER_MAX_BATCH_SIZE "128" CACHE STRING
|
|
||||||
"ggml: max. batch size for using peer access")
|
|
||||||
option(GGML_CUDA_NO_PEER_COPY "ggml: do not use peer to peer copies" OFF)
|
|
||||||
option(GGML_CUDA_NO_VMM "ggml: do not try to use CUDA VMM" OFF)
|
|
||||||
option(GGML_CUDA_FA "ggml: compile ggml FlashAttention CUDA kernels" ON)
|
|
||||||
option(GGML_CUDA_FA_ALL_QUANTS "ggml: compile all quants for FlashAttention" OFF)
|
|
||||||
option(GGML_CUDA_GRAPHS "ggml: use CUDA graphs (llama.cpp only)" ${GGML_CUDA_GRAPHS_DEFAULT})
|
|
||||||
set (GGML_CUDA_COMPRESSION_MODE "size" CACHE STRING
|
|
||||||
"ggml: cuda link binary compression mode; requires cuda 12.8+")
|
|
||||||
set_property(CACHE GGML_CUDA_COMPRESSION_MODE PROPERTY STRINGS "none;speed;balance;size")
|
|
||||||
|
|
||||||
option(GGML_HIP "ggml: use HIP" OFF)
|
|
||||||
option(GGML_HIP_GRAPHS "ggml: use HIP graph, experimental, slow" OFF)
|
|
||||||
option(GGML_HIP_NO_VMM "ggml: do not try to use HIP VMM" ON)
|
|
||||||
option(GGML_HIP_ROCWMMA_FATTN "ggml: enable rocWMMA for FlashAttention" OFF)
|
|
||||||
option(GGML_HIP_MMQ_MFMA "ggml: enable MFMA MMA for CDNA in MMQ" ON)
|
|
||||||
option(GGML_HIP_EXPORT_METRICS "ggml: enable kernel perf metrics output" OFF)
|
|
||||||
option(GGML_MUSA_GRAPHS "ggml: use MUSA graph, experimental, unstable" OFF)
|
|
||||||
option(GGML_MUSA_MUDNN_COPY "ggml: enable muDNN for accelerated copy" OFF)
|
|
||||||
option(GGML_VULKAN "ggml: use Vulkan" OFF)
|
|
||||||
option(GGML_VULKAN_CHECK_RESULTS "ggml: run Vulkan op checks" OFF)
|
|
||||||
option(GGML_VULKAN_DEBUG "ggml: enable Vulkan debug output" OFF)
|
|
||||||
option(GGML_VULKAN_MEMORY_DEBUG "ggml: enable Vulkan memory debug output" OFF)
|
|
||||||
option(GGML_VULKAN_SHADER_DEBUG_INFO "ggml: enable Vulkan shader debug info" OFF)
|
|
||||||
option(GGML_VULKAN_VALIDATE "ggml: enable Vulkan validation" OFF)
|
|
||||||
option(GGML_VULKAN_RUN_TESTS "ggml: run Vulkan tests" OFF)
|
|
||||||
option(GGML_WEBGPU "ggml: use WebGPU" OFF)
|
|
||||||
option(GGML_WEBGPU_DEBUG "ggml: enable WebGPU debug output" OFF)
|
|
||||||
option(GGML_WEBGPU_CPU_PROFILE "ggml: enable WebGPU profiling (CPU)" OFF)
|
|
||||||
option(GGML_WEBGPU_GPU_PROFILE "ggml: enable WebGPU profiling (GPU)" OFF)
|
|
||||||
option(GGML_WEBGPU_JSPI "ggml: use JSPI for WebGPU" ON)
|
|
||||||
option(GGML_ZDNN "ggml: use zDNN" OFF)
|
|
||||||
option(GGML_METAL "ggml: use Metal" ${GGML_METAL_DEFAULT})
|
|
||||||
option(GGML_METAL_NDEBUG "ggml: disable Metal debugging" OFF)
|
|
||||||
option(GGML_METAL_SHADER_DEBUG "ggml: compile Metal with -fno-fast-math" OFF)
|
|
||||||
option(GGML_METAL_EMBED_LIBRARY "ggml: embed Metal library" ${GGML_METAL})
|
|
||||||
set (GGML_METAL_MACOSX_VERSION_MIN "" CACHE STRING
|
|
||||||
"ggml: metal minimum macOS version")
|
|
||||||
set (GGML_METAL_STD "" CACHE STRING "ggml: metal standard version (-std flag)")
|
|
||||||
option(GGML_OPENMP "ggml: use OpenMP" ON)
|
|
||||||
option(GGML_RPC "ggml: use RPC" OFF)
|
|
||||||
option(GGML_SYCL "ggml: use SYCL" OFF)
|
|
||||||
option(GGML_SYCL_F16 "ggml: use 16 bit floats for sycl calculations" OFF)
|
|
||||||
option(GGML_SYCL_GRAPH "ggml: enable graphs in the SYCL backend" ON)
|
|
||||||
option(GGML_SYCL_DNN "ggml: enable oneDNN in the SYCL backend" ON)
|
|
||||||
set (GGML_SYCL_TARGET "INTEL" CACHE STRING
|
|
||||||
"ggml: sycl target device")
|
|
||||||
set (GGML_SYCL_DEVICE_ARCH "" CACHE STRING
|
|
||||||
"ggml: sycl device architecture")
|
|
||||||
|
|
||||||
option(GGML_OPENCL "ggml: use OpenCL" OFF)
|
|
||||||
option(GGML_OPENCL_PROFILING "ggml: use OpenCL profiling (increases overhead)" OFF)
|
|
||||||
option(GGML_OPENCL_EMBED_KERNELS "ggml: embed kernels" ON)
|
|
||||||
option(GGML_OPENCL_USE_ADRENO_KERNELS "ggml: use optimized kernels for Adreno" ON)
|
|
||||||
set (GGML_OPENCL_TARGET_VERSION "300" CACHE STRING
|
|
||||||
"gmml: OpenCL API version to target")
|
|
||||||
|
|
||||||
option(GGML_HEXAGON "ggml: enable Hexagon backend" OFF)
|
|
||||||
set(GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE 128 CACHE STRING "ggml: quantize group size (32, 64, or 128)")
|
|
||||||
|
|
||||||
# toolchain for vulkan-shaders-gen
|
|
||||||
set (GGML_VULKAN_SHADERS_GEN_TOOLCHAIN "" CACHE FILEPATH "ggml: toolchain file for vulkan-shaders-gen")
|
|
||||||
|
|
||||||
option(GGML_ZENDNN "ggml: use ZenDNN" OFF)
|
|
||||||
option(ZENDNN_ROOT "ggml: path to ZenDNN installation" "")
|
|
||||||
|
|
||||||
# extra artifacts
|
|
||||||
option(GGML_BUILD_TESTS "ggml: build tests" ${GGML_STANDALONE})
|
|
||||||
option(GGML_BUILD_EXAMPLES "ggml: build examples" ${GGML_STANDALONE})
|
|
||||||
|
|
||||||
#
|
|
||||||
# dependencies
|
|
||||||
#
|
|
||||||
|
|
||||||
set(CMAKE_C_STANDARD 11)
|
|
||||||
set(CMAKE_C_STANDARD_REQUIRED true)
|
|
||||||
|
|
||||||
set(CMAKE_CXX_STANDARD 17)
|
|
||||||
set(CMAKE_CXX_STANDARD_REQUIRED true)
|
|
||||||
|
|
||||||
set(THREADS_PREFER_PTHREAD_FLAG ON)
|
|
||||||
|
|
||||||
find_package(Threads REQUIRED)
|
|
||||||
|
|
||||||
include(GNUInstallDirs)
|
|
||||||
|
|
||||||
#
|
|
||||||
# build the library
|
|
||||||
#
|
|
||||||
|
|
||||||
add_subdirectory(src)
|
|
||||||
|
|
||||||
#
|
|
||||||
# tests and examples
|
|
||||||
#
|
|
||||||
|
|
||||||
if (GGML_BUILD_TESTS)
|
|
||||||
enable_testing()
|
|
||||||
add_subdirectory(tests)
|
|
||||||
endif ()
|
|
||||||
|
|
||||||
if (GGML_BUILD_EXAMPLES)
|
|
||||||
add_subdirectory(examples)
|
|
||||||
endif ()
|
|
||||||
|
|
||||||
#
|
|
||||||
# install
|
|
||||||
#
|
|
||||||
|
|
||||||
include(CMakePackageConfigHelpers)
|
|
||||||
|
|
||||||
# all public headers
|
|
||||||
set(GGML_PUBLIC_HEADERS
|
|
||||||
include/ggml.h
|
|
||||||
include/ggml-cpu.h
|
|
||||||
include/ggml-alloc.h
|
|
||||||
include/ggml-backend.h
|
|
||||||
include/ggml-blas.h
|
|
||||||
include/ggml-cann.h
|
|
||||||
include/ggml-cpp.h
|
|
||||||
include/ggml-cuda.h
|
|
||||||
include/ggml-opt.h
|
|
||||||
include/ggml-metal.h
|
|
||||||
include/ggml-rpc.h
|
|
||||||
include/ggml-sycl.h
|
|
||||||
include/ggml-vulkan.h
|
|
||||||
include/ggml-webgpu.h
|
|
||||||
include/ggml-zendnn.h
|
|
||||||
include/gguf.h)
|
|
||||||
|
|
||||||
set_target_properties(ggml PROPERTIES PUBLIC_HEADER "${GGML_PUBLIC_HEADERS}")
|
|
||||||
#if (GGML_METAL)
|
|
||||||
# set_target_properties(ggml PROPERTIES RESOURCE "${CMAKE_CURRENT_SOURCE_DIR}/src/ggml-metal.metal")
|
|
||||||
#endif()
|
|
||||||
install(TARGETS ggml LIBRARY PUBLIC_HEADER)
|
|
||||||
install(TARGETS ggml-base LIBRARY)
|
|
||||||
|
|
||||||
if (GGML_STANDALONE)
|
|
||||||
configure_file(${CMAKE_CURRENT_SOURCE_DIR}/ggml.pc.in
|
|
||||||
${CMAKE_CURRENT_BINARY_DIR}/ggml.pc
|
|
||||||
@ONLY)
|
|
||||||
|
|
||||||
install(FILES ${CMAKE_CURRENT_BINARY_DIR}/ggml.pc
|
|
||||||
DESTINATION share/pkgconfig)
|
|
||||||
endif()
|
|
||||||
|
|
||||||
#
|
|
||||||
# Create CMake package
|
|
||||||
#
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# Capture variables prefixed with GGML_.
|
|
||||||
|
|
||||||
set(variable_set_statements
|
|
||||||
"
|
|
||||||
####### Expanded from @GGML_VARIABLES_EXPANED@ by configure_package_config_file() #######
|
|
||||||
####### Any changes to this file will be overwritten by the next CMake run #######
|
|
||||||
|
|
||||||
")
|
|
||||||
|
|
||||||
set(GGML_SHARED_LIB ${BUILD_SHARED_LIBS})
|
|
||||||
|
|
||||||
get_cmake_property(all_variables VARIABLES)
|
|
||||||
foreach(variable_name IN LISTS all_variables)
|
|
||||||
if(variable_name MATCHES "^GGML_")
|
|
||||||
string(REPLACE ";" "\\;"
|
|
||||||
variable_value "${${variable_name}}")
|
|
||||||
|
|
||||||
set(variable_set_statements
|
|
||||||
"${variable_set_statements}set(${variable_name} \"${variable_value}\")\n")
|
|
||||||
endif()
|
|
||||||
endforeach()
|
|
||||||
|
|
||||||
set(GGML_VARIABLES_EXPANDED ${variable_set_statements})
|
|
||||||
|
|
||||||
# Create the CMake package and set install location.
|
|
||||||
|
|
||||||
set(GGML_INSTALL_VERSION ${GGML_VERSION})
|
|
||||||
set(GGML_INCLUDE_INSTALL_DIR ${CMAKE_INSTALL_INCLUDEDIR} CACHE PATH "Location of header files")
|
|
||||||
set(GGML_LIB_INSTALL_DIR ${CMAKE_INSTALL_LIBDIR} CACHE PATH "Location of library files")
|
|
||||||
set(GGML_BIN_INSTALL_DIR ${CMAKE_INSTALL_BINDIR} CACHE PATH "Location of binary files")
|
|
||||||
|
|
||||||
configure_package_config_file(
|
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/cmake/ggml-config.cmake.in
|
|
||||||
${CMAKE_CURRENT_BINARY_DIR}/ggml-config.cmake
|
|
||||||
INSTALL_DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/ggml
|
|
||||||
PATH_VARS GGML_INCLUDE_INSTALL_DIR
|
|
||||||
GGML_LIB_INSTALL_DIR
|
|
||||||
GGML_BIN_INSTALL_DIR)
|
|
||||||
|
|
||||||
write_basic_package_version_file(
|
|
||||||
${CMAKE_CURRENT_BINARY_DIR}/ggml-version.cmake
|
|
||||||
VERSION ${GGML_INSTALL_VERSION}
|
|
||||||
COMPATIBILITY SameMajorVersion)
|
|
||||||
|
|
||||||
target_compile_definitions(ggml-base PRIVATE
|
|
||||||
GGML_VERSION="${GGML_INSTALL_VERSION}"
|
|
||||||
GGML_COMMIT="${GGML_BUILD_COMMIT}"
|
|
||||||
)
|
|
||||||
message(STATUS "ggml version: ${GGML_INSTALL_VERSION}")
|
|
||||||
message(STATUS "ggml commit: ${GGML_BUILD_COMMIT}")
|
|
||||||
|
|
||||||
install(FILES ${CMAKE_CURRENT_BINARY_DIR}/ggml-config.cmake
|
|
||||||
${CMAKE_CURRENT_BINARY_DIR}/ggml-version.cmake
|
|
||||||
DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/ggml)
|
|
||||||
|
|
||||||
if (MSVC)
|
|
||||||
set(MSVC_WARNING_FLAGS
|
|
||||||
/wd4005 # Macro redefinition
|
|
||||||
/wd4244 # Conversion from one type to another type, possible loss of data
|
|
||||||
/wd4267 # Conversion from 'size_t' to a smaller type, possible loss of data
|
|
||||||
/wd4305 # Conversion from 'type1' to 'type2', possible loss of data
|
|
||||||
/wd4566 # Conversion from 'char' to 'wchar_t', possible loss of data
|
|
||||||
/wd4996 # Disable POSIX deprecation warnings
|
|
||||||
/wd4702 # Unreachable code warnings
|
|
||||||
)
|
|
||||||
set(MSVC_COMPILE_OPTIONS
|
|
||||||
"$<$<COMPILE_LANGUAGE:C>:/utf-8>"
|
|
||||||
"$<$<COMPILE_LANGUAGE:CXX>:/utf-8>"
|
|
||||||
)
|
|
||||||
function(configure_msvc_target target_name)
|
|
||||||
if(TARGET ${target_name})
|
|
||||||
target_compile_options(${target_name} PRIVATE ${MSVC_WARNING_FLAGS})
|
|
||||||
target_compile_options(${target_name} PRIVATE ${MSVC_COMPILE_OPTIONS})
|
|
||||||
endif()
|
|
||||||
endfunction()
|
|
||||||
|
|
||||||
configure_msvc_target(ggml-base)
|
|
||||||
configure_msvc_target(ggml)
|
|
||||||
configure_msvc_target(ggml-cpu)
|
|
||||||
configure_msvc_target(ggml-cpu-x64)
|
|
||||||
configure_msvc_target(ggml-cpu-sse42)
|
|
||||||
configure_msvc_target(ggml-cpu-sandybridge)
|
|
||||||
# __FMA__ and __F16C__ are not defined in MSVC, however they are implied with AVX2/AVX512
|
|
||||||
# skipping ggml-cpu-ivybridge
|
|
||||||
# skipping ggml-cpu-piledriver
|
|
||||||
configure_msvc_target(ggml-cpu-haswell)
|
|
||||||
configure_msvc_target(ggml-cpu-skylakex)
|
|
||||||
configure_msvc_target(ggml-cpu-cannonlake)
|
|
||||||
configure_msvc_target(ggml-cpu-cascadelake)
|
|
||||||
configure_msvc_target(ggml-cpu-icelake)
|
|
||||||
# MSVC 2022 doesn't support BF16 intrinsics without `/arch:AVX10.1` ?!
|
|
||||||
# https://learn.microsoft.com/en-us/cpp/intrinsics/x64-amd64-intrinsics-list?view=msvc-170
|
|
||||||
# https://learn.microsoft.com/en-us/cpp/build/reference/arch-x64?view=msvc-170
|
|
||||||
# skipping ggml-cpu-cooperlake
|
|
||||||
# skipping ggml-cpu-zen4
|
|
||||||
configure_msvc_target(ggml-cpu-alderlake)
|
|
||||||
# MSVC doesn't support AMX
|
|
||||||
# skipping ggml-cpu-sapphirerapids
|
|
||||||
|
|
||||||
if (GGML_BUILD_EXAMPLES)
|
|
||||||
configure_msvc_target(common-ggml)
|
|
||||||
configure_msvc_target(common)
|
|
||||||
|
|
||||||
configure_msvc_target(mnist-common)
|
|
||||||
configure_msvc_target(mnist-eval)
|
|
||||||
configure_msvc_target(mnist-train)
|
|
||||||
|
|
||||||
configure_msvc_target(gpt-2-ctx)
|
|
||||||
configure_msvc_target(gpt-2-alloc)
|
|
||||||
configure_msvc_target(gpt-2-backend)
|
|
||||||
configure_msvc_target(gpt-2-sched)
|
|
||||||
configure_msvc_target(gpt-2-quantize)
|
|
||||||
configure_msvc_target(gpt-2-batched)
|
|
||||||
|
|
||||||
configure_msvc_target(gpt-j)
|
|
||||||
configure_msvc_target(gpt-j-quantize)
|
|
||||||
|
|
||||||
configure_msvc_target(magika)
|
|
||||||
configure_msvc_target(yolov3-tiny)
|
|
||||||
configure_msvc_target(sam)
|
|
||||||
|
|
||||||
configure_msvc_target(simple-ctx)
|
|
||||||
configure_msvc_target(simple-backend)
|
|
||||||
endif()
|
|
||||||
|
|
||||||
if (GGML_BUILD_TESTS)
|
|
||||||
configure_msvc_target(test-mul-mat)
|
|
||||||
configure_msvc_target(test-arange)
|
|
||||||
configure_msvc_target(test-backend-ops)
|
|
||||||
configure_msvc_target(test-cont)
|
|
||||||
configure_msvc_target(test-conv-transpose)
|
|
||||||
configure_msvc_target(test-conv-transpose-1d)
|
|
||||||
configure_msvc_target(test-conv1d)
|
|
||||||
configure_msvc_target(test-conv2d)
|
|
||||||
configure_msvc_target(test-conv2d-dw)
|
|
||||||
configure_msvc_target(test-customop)
|
|
||||||
configure_msvc_target(test-dup)
|
|
||||||
configure_msvc_target(test-opt)
|
|
||||||
configure_msvc_target(test-pool)
|
|
||||||
endif ()
|
|
||||||
endif()
|
|
||||||
@@ -1,22 +0,0 @@
|
|||||||
find_package(Git)
|
|
||||||
|
|
||||||
# the commit's SHA1
|
|
||||||
execute_process(COMMAND
|
|
||||||
"${GIT_EXECUTABLE}" describe --match=NeVeRmAtCh --always --abbrev=8
|
|
||||||
WORKING_DIRECTORY "${CMAKE_SOURCE_DIR}"
|
|
||||||
OUTPUT_VARIABLE GIT_SHA1
|
|
||||||
ERROR_QUIET OUTPUT_STRIP_TRAILING_WHITESPACE)
|
|
||||||
|
|
||||||
# the date of the commit
|
|
||||||
execute_process(COMMAND
|
|
||||||
"${GIT_EXECUTABLE}" log -1 --format=%ad --date=local
|
|
||||||
WORKING_DIRECTORY "${CMAKE_SOURCE_DIR}"
|
|
||||||
OUTPUT_VARIABLE GIT_DATE
|
|
||||||
ERROR_QUIET OUTPUT_STRIP_TRAILING_WHITESPACE)
|
|
||||||
|
|
||||||
# the subject of the commit
|
|
||||||
execute_process(COMMAND
|
|
||||||
"${GIT_EXECUTABLE}" log -1 --format=%s
|
|
||||||
WORKING_DIRECTORY "${CMAKE_SOURCE_DIR}"
|
|
||||||
OUTPUT_VARIABLE GIT_COMMIT_SUBJECT
|
|
||||||
ERROR_QUIET OUTPUT_STRIP_TRAILING_WHITESPACE)
|
|
||||||
@@ -1,50 +0,0 @@
|
|||||||
function(ggml_get_flags CCID CCVER)
|
|
||||||
set(C_FLAGS "")
|
|
||||||
set(CXX_FLAGS "")
|
|
||||||
|
|
||||||
if (CCID MATCHES "Clang")
|
|
||||||
set(C_FLAGS -Wunreachable-code-break -Wunreachable-code-return)
|
|
||||||
set(CXX_FLAGS -Wunreachable-code-break -Wunreachable-code-return -Wmissing-prototypes -Wextra-semi)
|
|
||||||
|
|
||||||
if (
|
|
||||||
(CCID STREQUAL "Clang" AND CCVER VERSION_GREATER_EQUAL 3.8.0) OR
|
|
||||||
(CCID STREQUAL "AppleClang" AND CCVER VERSION_GREATER_EQUAL 7.3.0)
|
|
||||||
)
|
|
||||||
list(APPEND C_FLAGS -Wdouble-promotion)
|
|
||||||
endif()
|
|
||||||
elseif (CCID STREQUAL "GNU")
|
|
||||||
set(C_FLAGS -Wdouble-promotion)
|
|
||||||
set(CXX_FLAGS -Wno-array-bounds)
|
|
||||||
|
|
||||||
if (CCVER VERSION_GREATER_EQUAL 8.1.0)
|
|
||||||
list(APPEND CXX_FLAGS -Wextra-semi)
|
|
||||||
endif()
|
|
||||||
endif()
|
|
||||||
|
|
||||||
set(GF_C_FLAGS ${C_FLAGS} PARENT_SCOPE)
|
|
||||||
set(GF_CXX_FLAGS ${CXX_FLAGS} PARENT_SCOPE)
|
|
||||||
endfunction()
|
|
||||||
|
|
||||||
function(ggml_get_system_arch)
|
|
||||||
if (CMAKE_OSX_ARCHITECTURES STREQUAL "arm64" OR
|
|
||||||
CMAKE_GENERATOR_PLATFORM_LWR STREQUAL "arm64" OR
|
|
||||||
(NOT CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_GENERATOR_PLATFORM_LWR AND
|
|
||||||
CMAKE_SYSTEM_PROCESSOR MATCHES "^(aarch64|arm.*|ARM64)$"))
|
|
||||||
set(GGML_SYSTEM_ARCH "ARM" PARENT_SCOPE)
|
|
||||||
elseif (CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" OR
|
|
||||||
CMAKE_GENERATOR_PLATFORM_LWR MATCHES "^(x86_64|i686|amd64|x64|win32)$" OR
|
|
||||||
(NOT CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_GENERATOR_PLATFORM_LWR AND
|
|
||||||
CMAKE_SYSTEM_PROCESSOR MATCHES "^(x86_64|i686|AMD64|amd64)$"))
|
|
||||||
set(GGML_SYSTEM_ARCH "x86" PARENT_SCOPE)
|
|
||||||
elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc|power")
|
|
||||||
set(GGML_SYSTEM_ARCH "PowerPC" PARENT_SCOPE)
|
|
||||||
elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "loongarch64")
|
|
||||||
set(GGML_SYSTEM_ARCH "loongarch64" PARENT_SCOPE)
|
|
||||||
elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "riscv64")
|
|
||||||
set(GGML_SYSTEM_ARCH "riscv64" PARENT_SCOPE)
|
|
||||||
elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "s390x")
|
|
||||||
set(GGML_SYSTEM_ARCH "s390x" PARENT_SCOPE)
|
|
||||||
else()
|
|
||||||
set(GGML_SYSTEM_ARCH "UNKNOWN" PARENT_SCOPE)
|
|
||||||
endif()
|
|
||||||
endfunction()
|
|
||||||
@@ -1,191 +0,0 @@
|
|||||||
@PACKAGE_INIT@
|
|
||||||
|
|
||||||
@GGML_VARIABLES_EXPANDED@
|
|
||||||
|
|
||||||
# Find all dependencies before creating any target.
|
|
||||||
include(CMakeFindDependencyMacro)
|
|
||||||
find_dependency(Threads)
|
|
||||||
if (NOT GGML_SHARED_LIB)
|
|
||||||
set(GGML_CPU_INTERFACE_LINK_LIBRARIES "")
|
|
||||||
set(GGML_CPU_INTERFACE_LINK_OPTIONS "")
|
|
||||||
|
|
||||||
if (APPLE AND GGML_ACCELERATE)
|
|
||||||
find_library(ACCELERATE_FRAMEWORK Accelerate)
|
|
||||||
if(NOT ACCELERATE_FRAMEWORK)
|
|
||||||
set(${CMAKE_FIND_PACKAGE_NAME}_FOUND 0)
|
|
||||||
return()
|
|
||||||
endif()
|
|
||||||
list(APPEND GGML_CPU_INTERFACE_LINK_LIBRARIES ${ACCELERATE_FRAMEWORK})
|
|
||||||
endif()
|
|
||||||
|
|
||||||
if (GGML_OPENMP_ENABLED)
|
|
||||||
find_dependency(OpenMP)
|
|
||||||
list(APPEND GGML_CPU_INTERFACE_LINK_LIBRARIES OpenMP::OpenMP_C OpenMP::OpenMP_CXX)
|
|
||||||
endif()
|
|
||||||
|
|
||||||
if (GGML_CPU_HBM)
|
|
||||||
find_library(memkind memkind)
|
|
||||||
if(NOT memkind)
|
|
||||||
set(${CMAKE_FIND_PACKAGE_NAME}_FOUND 0)
|
|
||||||
return()
|
|
||||||
endif()
|
|
||||||
list(APPEND GGML_CPU_INTERFACE_LINK_LIBRARIES memkind)
|
|
||||||
endif()
|
|
||||||
|
|
||||||
if (GGML_BLAS)
|
|
||||||
find_dependency(BLAS)
|
|
||||||
list(APPEND GGML_BLAS_INTERFACE_LINK_LIBRARIES ${BLAS_LIBRARIES})
|
|
||||||
list(APPEND GGML_BLAS_INTERFACE_LINK_OPTIONS ${BLAS_LINKER_FLAGS})
|
|
||||||
endif()
|
|
||||||
|
|
||||||
if (GGML_CUDA)
|
|
||||||
set(GGML_CUDA_INTERFACE_LINK_LIBRARIES "")
|
|
||||||
find_dependency(CUDAToolkit)
|
|
||||||
if (GGML_STATIC)
|
|
||||||
list(APPEND GGML_CUDA_INTERFACE_LINK_LIBRARIES $<LINK_ONLY:CUDA::cudart_static>)
|
|
||||||
if (WIN32)
|
|
||||||
list(APPEND GGML_CUDA_INTERFACE_LINK_LIBRARIES $<LINK_ONLY:CUDA::cublas> $<LINK_ONLY:CUDA::cublasLt>)
|
|
||||||
else()
|
|
||||||
list(APPEND GGML_CUDA_INTERFACE_LINK_LIBRARIES $<LINK_ONLY:CUDA::cublas_static> $<LINK_ONLY:CUDA::cublasLt_static>)
|
|
||||||
endif()
|
|
||||||
endif()
|
|
||||||
if (NOT GGML_CUDA_NO_VMM)
|
|
||||||
list(APPEND GGML_CUDA_INTERFACE_LINK_LIBRARIES $<LINK_ONLY:CUDA::cuda_driver>)
|
|
||||||
endif()
|
|
||||||
endif()
|
|
||||||
|
|
||||||
if (GGML_METAL)
|
|
||||||
find_library(FOUNDATION_LIBRARY Foundation)
|
|
||||||
find_library(METAL_FRAMEWORK Metal)
|
|
||||||
find_library(METALKIT_FRAMEWORK MetalKit)
|
|
||||||
if(NOT FOUNDATION_LIBRARY OR NOT METAL_FRAMEWORK OR NOT METALKIT_FRAMEWORK)
|
|
||||||
set(${CMAKE_FIND_PACKAGE_NAME}_FOUND 0)
|
|
||||||
return()
|
|
||||||
endif()
|
|
||||||
set(GGML_METAL_INTERFACE_LINK_LIBRARIES
|
|
||||||
${FOUNDATION_LIBRARY} ${METAL_FRAMEWORK} ${METALKIT_FRAMEWORK})
|
|
||||||
endif()
|
|
||||||
|
|
||||||
if (GGML_OPENCL)
|
|
||||||
find_dependency(OpenCL)
|
|
||||||
set(GGML_OPENCL_INTERFACE_LINK_LIBRARIES $<LINK_ONLY:OpenCL::OpenCL>)
|
|
||||||
endif()
|
|
||||||
|
|
||||||
if (GGML_VULKAN)
|
|
||||||
find_dependency(Vulkan)
|
|
||||||
set(GGML_VULKAN_INTERFACE_LINK_LIBRARIES $<LINK_ONLY:Vulkan::Vulkan>)
|
|
||||||
endif()
|
|
||||||
|
|
||||||
if (GGML_HIP)
|
|
||||||
find_dependency(hip)
|
|
||||||
find_dependency(hipblas)
|
|
||||||
find_dependency(rocblas)
|
|
||||||
set(GGML_HIP_INTERFACE_LINK_LIBRARIES hip::host roc::rocblas roc::hipblas)
|
|
||||||
endif()
|
|
||||||
|
|
||||||
if (GGML_SYCL)
|
|
||||||
set(GGML_SYCL_INTERFACE_LINK_LIBRARIES "")
|
|
||||||
find_package(DNNL)
|
|
||||||
if (${DNNL_FOUND} AND GGML_SYCL_TARGET STREQUAL "INTEL")
|
|
||||||
list(APPEND GGML_SYCL_INTERFACE_LINK_LIBRARIES DNNL::dnnl)
|
|
||||||
endif()
|
|
||||||
if (WIN32)
|
|
||||||
find_dependency(IntelSYCL)
|
|
||||||
find_dependency(MKL)
|
|
||||||
list(APPEND GGML_SYCL_INTERFACE_LINK_LIBRARIES IntelSYCL::SYCL_CXX MKL::MKL MKL::MKL_SYCL)
|
|
||||||
endif()
|
|
||||||
endif()
|
|
||||||
endif()
|
|
||||||
|
|
||||||
set_and_check(GGML_INCLUDE_DIR "@PACKAGE_GGML_INCLUDE_INSTALL_DIR@")
|
|
||||||
set_and_check(GGML_LIB_DIR "@PACKAGE_GGML_LIB_INSTALL_DIR@")
|
|
||||||
#set_and_check(GGML_BIN_DIR "@PACKAGE_GGML_BIN_INSTALL_DIR@")
|
|
||||||
|
|
||||||
if(NOT TARGET ggml::ggml)
|
|
||||||
find_package(Threads REQUIRED)
|
|
||||||
|
|
||||||
find_library(GGML_LIBRARY ggml
|
|
||||||
REQUIRED
|
|
||||||
HINTS ${GGML_LIB_DIR}
|
|
||||||
NO_CMAKE_FIND_ROOT_PATH)
|
|
||||||
|
|
||||||
add_library(ggml::ggml UNKNOWN IMPORTED)
|
|
||||||
set_target_properties(ggml::ggml
|
|
||||||
PROPERTIES
|
|
||||||
IMPORTED_LOCATION "${GGML_LIBRARY}")
|
|
||||||
|
|
||||||
find_library(GGML_BASE_LIBRARY ggml-base
|
|
||||||
REQUIRED
|
|
||||||
HINTS ${GGML_LIB_DIR}
|
|
||||||
NO_CMAKE_FIND_ROOT_PATH)
|
|
||||||
|
|
||||||
add_library(ggml::ggml-base UNKNOWN IMPORTED)
|
|
||||||
set_target_properties(ggml::ggml-base
|
|
||||||
PROPERTIES
|
|
||||||
IMPORTED_LOCATION "${GGML_BASE_LIBRARY}")
|
|
||||||
|
|
||||||
set(_ggml_all_targets "")
|
|
||||||
if (NOT GGML_BACKEND_DL)
|
|
||||||
foreach(_ggml_backend ${GGML_AVAILABLE_BACKENDS})
|
|
||||||
string(REPLACE "-" "_" _ggml_backend_pfx "${_ggml_backend}")
|
|
||||||
string(TOUPPER "${_ggml_backend_pfx}" _ggml_backend_pfx)
|
|
||||||
|
|
||||||
find_library(${_ggml_backend_pfx}_LIBRARY ${_ggml_backend}
|
|
||||||
REQUIRED
|
|
||||||
HINTS ${GGML_LIB_DIR}
|
|
||||||
NO_CMAKE_FIND_ROOT_PATH)
|
|
||||||
|
|
||||||
message(STATUS "Found ${${_ggml_backend_pfx}_LIBRARY}")
|
|
||||||
|
|
||||||
add_library(ggml::${_ggml_backend} UNKNOWN IMPORTED)
|
|
||||||
set_target_properties(ggml::${_ggml_backend}
|
|
||||||
PROPERTIES
|
|
||||||
INTERFACE_INCLUDE_DIRECTORIES "${GGML_INCLUDE_DIR}"
|
|
||||||
IMPORTED_LINK_INTERFACE_LANGUAGES "CXX"
|
|
||||||
IMPORTED_LOCATION "${${_ggml_backend_pfx}_LIBRARY}"
|
|
||||||
INTERFACE_COMPILE_FEATURES c_std_90
|
|
||||||
POSITION_INDEPENDENT_CODE ON)
|
|
||||||
|
|
||||||
string(REGEX MATCH "^ggml-cpu" is_cpu_variant "${_ggml_backend}")
|
|
||||||
if(is_cpu_variant)
|
|
||||||
list(APPEND GGML_CPU_INTERFACE_LINK_LIBRARIES "ggml::ggml-base")
|
|
||||||
set_target_properties(ggml::${_ggml_backend}
|
|
||||||
PROPERTIES
|
|
||||||
INTERFACE_LINK_LIBRARIES "${GGML_CPU_INTERFACE_LINK_LIBRARIES}")
|
|
||||||
|
|
||||||
if(GGML_CPU_INTERFACE_LINK_OPTIONS)
|
|
||||||
set_target_properties(ggml::${_ggml_backend}
|
|
||||||
PROPERTIES
|
|
||||||
INTERFACE_LINK_OPTIONS "${GGML_CPU_INTERFACE_LINK_OPTIONS}")
|
|
||||||
endif()
|
|
||||||
|
|
||||||
else()
|
|
||||||
list(APPEND ${_ggml_backend_pfx}_INTERFACE_LINK_LIBRARIES "ggml::ggml-base")
|
|
||||||
set_target_properties(ggml::${_ggml_backend}
|
|
||||||
PROPERTIES
|
|
||||||
INTERFACE_LINK_LIBRARIES "${${_ggml_backend_pfx}_INTERFACE_LINK_LIBRARIES}")
|
|
||||||
|
|
||||||
if(${_ggml_backend_pfx}_INTERFACE_LINK_OPTIONS)
|
|
||||||
set_target_properties(ggml::${_ggml_backend}
|
|
||||||
PROPERTIES
|
|
||||||
INTERFACE_LINK_OPTIONS "${${_ggml_backend_pfx}_INTERFACE_LINK_OPTIONS}")
|
|
||||||
endif()
|
|
||||||
endif()
|
|
||||||
|
|
||||||
list(APPEND _ggml_all_targets ggml::${_ggml_backend})
|
|
||||||
endforeach()
|
|
||||||
endif()
|
|
||||||
|
|
||||||
list(APPEND GGML_INTERFACE_LINK_LIBRARIES ggml::ggml-base "${_ggml_all_targets}")
|
|
||||||
set_target_properties(ggml::ggml
|
|
||||||
PROPERTIES
|
|
||||||
INTERFACE_LINK_LIBRARIES "${GGML_INTERFACE_LINK_LIBRARIES}")
|
|
||||||
|
|
||||||
add_library(ggml::all INTERFACE IMPORTED)
|
|
||||||
set_target_properties(ggml::all
|
|
||||||
PROPERTIES
|
|
||||||
INTERFACE_LINK_LIBRARIES "${_ggml_all_targets}")
|
|
||||||
|
|
||||||
endif()
|
|
||||||
|
|
||||||
check_required_components(ggml)
|
|
||||||
@@ -1,85 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
#include "ggml.h"
|
|
||||||
|
|
||||||
#ifdef __cplusplus
|
|
||||||
extern "C" {
|
|
||||||
#endif
|
|
||||||
|
|
||||||
typedef struct ggml_backend_buffer_type * ggml_backend_buffer_type_t;
|
|
||||||
typedef struct ggml_backend_buffer * ggml_backend_buffer_t;
|
|
||||||
typedef struct ggml_backend * ggml_backend_t;
|
|
||||||
|
|
||||||
// Tensor allocator
|
|
||||||
struct ggml_tallocr {
|
|
||||||
ggml_backend_buffer_t buffer;
|
|
||||||
void * base;
|
|
||||||
size_t alignment;
|
|
||||||
size_t offset;
|
|
||||||
};
|
|
||||||
|
|
||||||
GGML_API struct ggml_tallocr ggml_tallocr_new(ggml_backend_buffer_t buffer);
|
|
||||||
GGML_API enum ggml_status ggml_tallocr_alloc(struct ggml_tallocr * talloc, struct ggml_tensor * tensor);
|
|
||||||
|
|
||||||
// Graph allocator
|
|
||||||
/*
|
|
||||||
Example usage:
|
|
||||||
ggml_gallocr_t galloc = ggml_gallocr_new(ggml_backend_cpu_buffer_type());
|
|
||||||
|
|
||||||
// optional: create a worst-case graph and reserve the buffers to avoid reallocations
|
|
||||||
ggml_gallocr_reserve(galloc, build_graph(max_batch));
|
|
||||||
|
|
||||||
// allocate the graph
|
|
||||||
struct ggml_cgraph * graph = build_graph(batch);
|
|
||||||
ggml_gallocr_alloc_graph(galloc, graph);
|
|
||||||
|
|
||||||
printf("compute buffer size: %zu bytes\n", ggml_gallocr_get_buffer_size(galloc, 0));
|
|
||||||
|
|
||||||
// evaluate the graph
|
|
||||||
ggml_backend_graph_compute(backend, graph);
|
|
||||||
*/
|
|
||||||
|
|
||||||
// special tensor flags for use with the graph allocator:
|
|
||||||
// ggml_set_input(): all input tensors are allocated at the beginning of the graph in non-overlapping addresses
|
|
||||||
// ggml_set_output(): output tensors are never freed and never overwritten
|
|
||||||
|
|
||||||
typedef struct ggml_gallocr * ggml_gallocr_t;
|
|
||||||
|
|
||||||
GGML_API ggml_gallocr_t ggml_gallocr_new(ggml_backend_buffer_type_t buft);
|
|
||||||
GGML_API ggml_gallocr_t ggml_gallocr_new_n(ggml_backend_buffer_type_t * bufts, int n_bufs);
|
|
||||||
GGML_API void ggml_gallocr_free(ggml_gallocr_t galloc);
|
|
||||||
|
|
||||||
// pre-allocate buffers from a measure graph - does not allocate or modify the graph
|
|
||||||
// call with a worst-case graph to avoid buffer reallocations
|
|
||||||
// not strictly required for single buffer usage: ggml_gallocr_alloc_graph will reallocate the buffers automatically if needed
|
|
||||||
// returns false if the buffer allocation failed
|
|
||||||
// ggml_gallocr_resrve_n_size writes the buffer sizes per galloc buffer that would be allocated by ggml_gallocr_reserve_n to sizes
|
|
||||||
GGML_API bool ggml_gallocr_reserve(ggml_gallocr_t galloc, struct ggml_cgraph * graph);
|
|
||||||
GGML_API void ggml_gallocr_reserve_n_size(
|
|
||||||
ggml_gallocr_t galloc,
|
|
||||||
struct ggml_cgraph * graph,
|
|
||||||
const int * node_buffer_ids,
|
|
||||||
const int * leaf_buffer_ids,
|
|
||||||
size_t * sizes);
|
|
||||||
GGML_API bool ggml_gallocr_reserve_n(
|
|
||||||
ggml_gallocr_t galloc,
|
|
||||||
struct ggml_cgraph * graph,
|
|
||||||
const int * node_buffer_ids,
|
|
||||||
const int * leaf_buffer_ids);
|
|
||||||
|
|
||||||
// automatic reallocation if the topology changes when using a single buffer
|
|
||||||
// returns false if using multiple buffers and a re-allocation is needed (call ggml_gallocr_reserve_n first to set the node buffers)
|
|
||||||
GGML_API bool ggml_gallocr_alloc_graph(ggml_gallocr_t galloc, struct ggml_cgraph * graph);
|
|
||||||
|
|
||||||
GGML_API size_t ggml_gallocr_get_buffer_size(ggml_gallocr_t galloc, int buffer_id);
|
|
||||||
|
|
||||||
// Utils
|
|
||||||
// Create a buffer and allocate all the tensors in a ggml_context
|
|
||||||
// ggml_backend_alloc_ctx_tensors_from_buft_size returns the size of the buffer that would be allocated by ggml_backend_alloc_ctx_tensors_from_buft
|
|
||||||
GGML_API size_t ggml_backend_alloc_ctx_tensors_from_buft_size(struct ggml_context * ctx, ggml_backend_buffer_type_t buft);
|
|
||||||
GGML_API struct ggml_backend_buffer * ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, ggml_backend_buffer_type_t buft);
|
|
||||||
GGML_API struct ggml_backend_buffer * ggml_backend_alloc_ctx_tensors(struct ggml_context * ctx, ggml_backend_t backend);
|
|
||||||
|
|
||||||
#ifdef __cplusplus
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
@@ -1,373 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
#include "ggml.h"
|
|
||||||
#include "ggml-alloc.h"
|
|
||||||
|
|
||||||
#ifdef GGML_BACKEND_SHARED
|
|
||||||
# if defined(_WIN32) && !defined(__MINGW32__)
|
|
||||||
# ifdef GGML_BACKEND_BUILD
|
|
||||||
# define GGML_BACKEND_API __declspec(dllexport) extern
|
|
||||||
# else
|
|
||||||
# define GGML_BACKEND_API __declspec(dllimport) extern
|
|
||||||
# endif
|
|
||||||
# else
|
|
||||||
# define GGML_BACKEND_API __attribute__ ((visibility ("default"))) extern
|
|
||||||
# endif
|
|
||||||
#else
|
|
||||||
# define GGML_BACKEND_API extern
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#ifdef __cplusplus
|
|
||||||
extern "C" {
|
|
||||||
#endif
|
|
||||||
|
|
||||||
typedef struct ggml_backend_buffer_type * ggml_backend_buffer_type_t;
|
|
||||||
typedef struct ggml_backend_buffer * ggml_backend_buffer_t;
|
|
||||||
typedef struct ggml_backend_event * ggml_backend_event_t;
|
|
||||||
typedef struct ggml_backend * ggml_backend_t;
|
|
||||||
typedef void * ggml_backend_graph_plan_t;
|
|
||||||
typedef struct ggml_backend_reg * ggml_backend_reg_t;
|
|
||||||
typedef struct ggml_backend_device * ggml_backend_dev_t;
|
|
||||||
|
|
||||||
|
|
||||||
//
|
|
||||||
// Backend buffer type
|
|
||||||
//
|
|
||||||
|
|
||||||
GGML_API const char * ggml_backend_buft_name (ggml_backend_buffer_type_t buft);
|
|
||||||
GGML_API ggml_backend_buffer_t ggml_backend_buft_alloc_buffer (ggml_backend_buffer_type_t buft, size_t size);
|
|
||||||
GGML_API size_t ggml_backend_buft_get_alignment (ggml_backend_buffer_type_t buft);
|
|
||||||
GGML_API size_t ggml_backend_buft_get_max_size (ggml_backend_buffer_type_t buft);
|
|
||||||
GGML_API size_t ggml_backend_buft_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor);
|
|
||||||
GGML_API bool ggml_backend_buft_is_host (ggml_backend_buffer_type_t buft);
|
|
||||||
GGML_API ggml_backend_dev_t ggml_backend_buft_get_device (ggml_backend_buffer_type_t buft);
|
|
||||||
|
|
||||||
//
|
|
||||||
// Backend buffer
|
|
||||||
//
|
|
||||||
|
|
||||||
enum ggml_backend_buffer_usage {
|
|
||||||
GGML_BACKEND_BUFFER_USAGE_ANY = 0,
|
|
||||||
GGML_BACKEND_BUFFER_USAGE_WEIGHTS = 1,
|
|
||||||
GGML_BACKEND_BUFFER_USAGE_COMPUTE = 2,
|
|
||||||
};
|
|
||||||
|
|
||||||
GGML_API const char * ggml_backend_buffer_name (ggml_backend_buffer_t buffer);
|
|
||||||
GGML_API void ggml_backend_buffer_free (ggml_backend_buffer_t buffer);
|
|
||||||
GGML_API void * ggml_backend_buffer_get_base (ggml_backend_buffer_t buffer);
|
|
||||||
GGML_API size_t ggml_backend_buffer_get_size (ggml_backend_buffer_t buffer);
|
|
||||||
GGML_API enum ggml_status ggml_backend_buffer_init_tensor (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
|
|
||||||
GGML_API size_t ggml_backend_buffer_get_alignment (ggml_backend_buffer_t buffer);
|
|
||||||
GGML_API size_t ggml_backend_buffer_get_max_size (ggml_backend_buffer_t buffer);
|
|
||||||
GGML_API size_t ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor);
|
|
||||||
GGML_API void ggml_backend_buffer_clear (ggml_backend_buffer_t buffer, uint8_t value);
|
|
||||||
GGML_API bool ggml_backend_buffer_is_host (ggml_backend_buffer_t buffer);
|
|
||||||
GGML_API void ggml_backend_buffer_set_usage (ggml_backend_buffer_t buffer, enum ggml_backend_buffer_usage usage);
|
|
||||||
GGML_API enum ggml_backend_buffer_usage ggml_backend_buffer_get_usage (ggml_backend_buffer_t buffer);
|
|
||||||
GGML_API ggml_backend_buffer_type_t ggml_backend_buffer_get_type (ggml_backend_buffer_t buffer);
|
|
||||||
GGML_API void ggml_backend_buffer_reset (ggml_backend_buffer_t buffer);
|
|
||||||
|
|
||||||
// tensor copy between different backends
|
|
||||||
GGML_API void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst);
|
|
||||||
|
|
||||||
//
|
|
||||||
// Backend (stream)
|
|
||||||
//
|
|
||||||
|
|
||||||
GGML_API ggml_guid_t ggml_backend_guid(ggml_backend_t backend);
|
|
||||||
GGML_API const char * ggml_backend_name(ggml_backend_t backend);
|
|
||||||
GGML_API void ggml_backend_free(ggml_backend_t backend);
|
|
||||||
|
|
||||||
GGML_API ggml_backend_buffer_type_t ggml_backend_get_default_buffer_type(ggml_backend_t backend);
|
|
||||||
GGML_API ggml_backend_buffer_t ggml_backend_alloc_buffer(ggml_backend_t backend, size_t size);
|
|
||||||
GGML_API size_t ggml_backend_get_alignment(ggml_backend_t backend);
|
|
||||||
GGML_API size_t ggml_backend_get_max_size(ggml_backend_t backend);
|
|
||||||
|
|
||||||
GGML_API void ggml_backend_tensor_set_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
|
|
||||||
GGML_API void ggml_backend_tensor_get_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
|
|
||||||
|
|
||||||
// "offset" refers to the offset in tensor->data for setting/getting data
|
|
||||||
GGML_API void ggml_backend_tensor_set( struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
|
|
||||||
GGML_API void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
|
|
||||||
GGML_API void ggml_backend_tensor_memset( struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size);
|
|
||||||
|
|
||||||
GGML_API void ggml_backend_synchronize(ggml_backend_t backend);
|
|
||||||
|
|
||||||
GGML_API ggml_backend_graph_plan_t ggml_backend_graph_plan_create(ggml_backend_t backend, struct ggml_cgraph * cgraph);
|
|
||||||
GGML_API void ggml_backend_graph_plan_free (ggml_backend_t backend, ggml_backend_graph_plan_t plan);
|
|
||||||
|
|
||||||
GGML_API enum ggml_status ggml_backend_graph_plan_compute (ggml_backend_t backend, ggml_backend_graph_plan_t plan);
|
|
||||||
GGML_API enum ggml_status ggml_backend_graph_compute (ggml_backend_t backend, struct ggml_cgraph * cgraph);
|
|
||||||
GGML_API enum ggml_status ggml_backend_graph_compute_async(ggml_backend_t backend, struct ggml_cgraph * cgraph);
|
|
||||||
|
|
||||||
// NOTE: will be removed, use device version instead
|
|
||||||
GGML_API bool ggml_backend_supports_op(ggml_backend_t backend, const struct ggml_tensor * op);
|
|
||||||
GGML_API bool ggml_backend_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft);
|
|
||||||
GGML_API bool ggml_backend_offload_op(ggml_backend_t backend, const struct ggml_tensor * op);
|
|
||||||
|
|
||||||
// asynchronous copy
|
|
||||||
// the copy is performed after all the currently queued operations in backend_src
|
|
||||||
// backend_dst will wait for the copy to complete before performing other operations
|
|
||||||
// automatic fallback to sync copy if async is not supported
|
|
||||||
GGML_API void ggml_backend_tensor_copy_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, struct ggml_tensor * src, struct ggml_tensor * dst);
|
|
||||||
|
|
||||||
GGML_API ggml_backend_dev_t ggml_backend_get_device(ggml_backend_t backend);
|
|
||||||
|
|
||||||
//
|
|
||||||
// Events
|
|
||||||
//
|
|
||||||
|
|
||||||
GGML_API ggml_backend_event_t ggml_backend_event_new(ggml_backend_dev_t device);
|
|
||||||
GGML_API void ggml_backend_event_free(ggml_backend_event_t event);
|
|
||||||
GGML_API void ggml_backend_event_record(ggml_backend_event_t event, ggml_backend_t backend);
|
|
||||||
GGML_API void ggml_backend_event_synchronize(ggml_backend_event_t event);
|
|
||||||
GGML_API void ggml_backend_event_wait(ggml_backend_t backend, ggml_backend_event_t event);
|
|
||||||
|
|
||||||
//
|
|
||||||
// Backend device
|
|
||||||
//
|
|
||||||
|
|
||||||
enum ggml_backend_dev_type {
|
|
||||||
// CPU device using system memory
|
|
||||||
GGML_BACKEND_DEVICE_TYPE_CPU,
|
|
||||||
// GPU device using dedicated memory
|
|
||||||
GGML_BACKEND_DEVICE_TYPE_GPU,
|
|
||||||
// integrated GPU device using host memory
|
|
||||||
GGML_BACKEND_DEVICE_TYPE_IGPU,
|
|
||||||
// accelerator devices intended to be used together with the CPU backend (e.g. BLAS or AMX)
|
|
||||||
GGML_BACKEND_DEVICE_TYPE_ACCEL
|
|
||||||
};
|
|
||||||
|
|
||||||
// functionality supported by the device
|
|
||||||
struct ggml_backend_dev_caps {
|
|
||||||
// asynchronous operations
|
|
||||||
bool async;
|
|
||||||
// pinned host buffer
|
|
||||||
bool host_buffer;
|
|
||||||
// creating buffers from host ptr
|
|
||||||
bool buffer_from_host_ptr;
|
|
||||||
// event synchronization
|
|
||||||
bool events;
|
|
||||||
};
|
|
||||||
|
|
||||||
// all the device properties
|
|
||||||
struct ggml_backend_dev_props {
|
|
||||||
// device name
|
|
||||||
const char * name;
|
|
||||||
// device description
|
|
||||||
const char * description;
|
|
||||||
// device free memory in bytes
|
|
||||||
size_t memory_free;
|
|
||||||
// device total memory in bytes
|
|
||||||
size_t memory_total;
|
|
||||||
// device type
|
|
||||||
enum ggml_backend_dev_type type;
|
|
||||||
// device id
|
|
||||||
// for PCI devices, this should be the PCI bus id formatted as "domain:bus:device.function" (e.g. "0000:01:00.0")
|
|
||||||
// if the id is unknown, this should be NULL
|
|
||||||
const char * device_id;
|
|
||||||
// device capabilities
|
|
||||||
struct ggml_backend_dev_caps caps;
|
|
||||||
};
|
|
||||||
|
|
||||||
GGML_API const char * ggml_backend_dev_name(ggml_backend_dev_t device);
|
|
||||||
GGML_API const char * ggml_backend_dev_description(ggml_backend_dev_t device);
|
|
||||||
GGML_API void ggml_backend_dev_memory(ggml_backend_dev_t device, size_t * free, size_t * total);
|
|
||||||
GGML_API enum ggml_backend_dev_type ggml_backend_dev_type(ggml_backend_dev_t device);
|
|
||||||
GGML_API void ggml_backend_dev_get_props(ggml_backend_dev_t device, struct ggml_backend_dev_props * props);
|
|
||||||
GGML_API ggml_backend_reg_t ggml_backend_dev_backend_reg(ggml_backend_dev_t device);
|
|
||||||
GGML_API ggml_backend_t ggml_backend_dev_init(ggml_backend_dev_t device, const char * params);
|
|
||||||
GGML_API ggml_backend_buffer_type_t ggml_backend_dev_buffer_type(ggml_backend_dev_t device);
|
|
||||||
GGML_API ggml_backend_buffer_type_t ggml_backend_dev_host_buffer_type(ggml_backend_dev_t device);
|
|
||||||
GGML_API ggml_backend_buffer_t ggml_backend_dev_buffer_from_host_ptr(ggml_backend_dev_t device, void * ptr, size_t size, size_t max_tensor_size);
|
|
||||||
|
|
||||||
GGML_API bool ggml_backend_dev_supports_op(ggml_backend_dev_t device, const struct ggml_tensor * op);
|
|
||||||
GGML_API bool ggml_backend_dev_supports_buft(ggml_backend_dev_t device, ggml_backend_buffer_type_t buft);
|
|
||||||
GGML_API bool ggml_backend_dev_offload_op(ggml_backend_dev_t device, const struct ggml_tensor * op);
|
|
||||||
|
|
||||||
//
|
|
||||||
// Backend (reg)
|
|
||||||
//
|
|
||||||
|
|
||||||
GGML_API const char * ggml_backend_reg_name(ggml_backend_reg_t reg);
|
|
||||||
GGML_API size_t ggml_backend_reg_dev_count(ggml_backend_reg_t reg);
|
|
||||||
GGML_API ggml_backend_dev_t ggml_backend_reg_dev_get(ggml_backend_reg_t reg, size_t index);
|
|
||||||
GGML_API void * ggml_backend_reg_get_proc_address(ggml_backend_reg_t reg, const char * name);
|
|
||||||
|
|
||||||
// Common functions that may be obtained using ggml_backend_reg_get_proc_address
|
|
||||||
|
|
||||||
// Split buffer type for tensor parallelism
|
|
||||||
typedef ggml_backend_buffer_type_t (*ggml_backend_split_buffer_type_t)(int main_device, const float * tensor_split);
|
|
||||||
// Set the number of threads for the backend
|
|
||||||
typedef void (*ggml_backend_set_n_threads_t)(ggml_backend_t backend, int n_threads);
|
|
||||||
// Get additional buffer types provided by the device (returns a NULL-terminated array)
|
|
||||||
typedef ggml_backend_buffer_type_t * (*ggml_backend_dev_get_extra_bufts_t)(ggml_backend_dev_t device);
|
|
||||||
// Set the abort callback for the backend
|
|
||||||
typedef void (*ggml_backend_set_abort_callback_t)(ggml_backend_t backend, ggml_abort_callback abort_callback, void * abort_callback_data);
|
|
||||||
// Get a list of feature flags supported by the backend (returns a NULL-terminated array)
|
|
||||||
struct ggml_backend_feature {
|
|
||||||
const char * name;
|
|
||||||
const char * value;
|
|
||||||
};
|
|
||||||
typedef struct ggml_backend_feature * (*ggml_backend_get_features_t)(ggml_backend_reg_t reg);
|
|
||||||
|
|
||||||
//
|
|
||||||
// Backend registry
|
|
||||||
//
|
|
||||||
|
|
||||||
GGML_API void ggml_backend_register(ggml_backend_reg_t reg);
|
|
||||||
|
|
||||||
GGML_API void ggml_backend_device_register(ggml_backend_dev_t device);
|
|
||||||
|
|
||||||
// Backend (reg) enumeration
|
|
||||||
GGML_API size_t ggml_backend_reg_count(void);
|
|
||||||
GGML_API ggml_backend_reg_t ggml_backend_reg_get(size_t index);
|
|
||||||
GGML_API ggml_backend_reg_t ggml_backend_reg_by_name(const char * name);
|
|
||||||
|
|
||||||
// Device enumeration
|
|
||||||
GGML_API size_t ggml_backend_dev_count(void);
|
|
||||||
GGML_API ggml_backend_dev_t ggml_backend_dev_get(size_t index);
|
|
||||||
GGML_API ggml_backend_dev_t ggml_backend_dev_by_name(const char * name);
|
|
||||||
GGML_API ggml_backend_dev_t ggml_backend_dev_by_type(enum ggml_backend_dev_type type);
|
|
||||||
|
|
||||||
// Direct backend (stream) initialization
|
|
||||||
// = ggml_backend_dev_init(ggml_backend_dev_by_name(name), params)
|
|
||||||
GGML_API ggml_backend_t ggml_backend_init_by_name(const char * name, const char * params);
|
|
||||||
// = ggml_backend_dev_init(ggml_backend_dev_by_type(type), params)
|
|
||||||
GGML_API ggml_backend_t ggml_backend_init_by_type(enum ggml_backend_dev_type type, const char * params);
|
|
||||||
// = ggml_backend_dev_init(ggml_backend_dev_by_type(GPU) OR ggml_backend_dev_by_type(CPU), NULL)
|
|
||||||
GGML_API ggml_backend_t ggml_backend_init_best(void);
|
|
||||||
|
|
||||||
// Load a backend from a dynamic library and register it
|
|
||||||
GGML_API ggml_backend_reg_t ggml_backend_load(const char * path);
|
|
||||||
// Unload a backend if loaded dynamically and unregister it
|
|
||||||
GGML_API void ggml_backend_unload(ggml_backend_reg_t reg);
|
|
||||||
// Load all known backends from dynamic libraries
|
|
||||||
GGML_API void ggml_backend_load_all(void);
|
|
||||||
GGML_API void ggml_backend_load_all_from_path(const char * dir_path);
|
|
||||||
|
|
||||||
//
|
|
||||||
// Backend scheduler
|
|
||||||
//
|
|
||||||
|
|
||||||
// The backend scheduler allows for multiple backend devices to be used together
|
|
||||||
// Handles compute buffer allocation, assignment of tensors to backends, and copying of tensors between backends
|
|
||||||
// The backends are selected based on:
|
|
||||||
// - the backend that supports the operation
|
|
||||||
// - the location of the pre-allocated tensors (e.g. the weights)
|
|
||||||
/*
|
|
||||||
Example usage:
|
|
||||||
|
|
||||||
// operations that use tensors allocated in a buffer with USAGE_WEIGHTS will be assigned
|
|
||||||
// preferrably to run on the same backend as the buffer
|
|
||||||
ggml_backend_buffer_set_usage(buf_weights, GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
|
|
||||||
|
|
||||||
sched = ggml_backend_sched_new({backend_gpu, backend_gpu2, backend_cpu}, NULL, num_backends, GGML_DEFAULT_GRAPH_SIZE, false, true);
|
|
||||||
|
|
||||||
// initialize buffers from a max size graph (optional)
|
|
||||||
reserve_graph = build_graph(sched, max_batch_size);
|
|
||||||
|
|
||||||
// manually assign nodes to a backend (optional, should not be needed in most cases)
|
|
||||||
struct ggml_tensor * node = ggml_mul_mat(ctx, ...);
|
|
||||||
ggml_backend_sched_set_tensor_backend(sched, node, backend_gpu);
|
|
||||||
|
|
||||||
ggml_backend_sched_reserve(sched, reserve_graph);
|
|
||||||
|
|
||||||
// compute
|
|
||||||
graph = build_graph(sched); // the graph and its tensors are single-use in terms of allocation, multi-use in terms of computation
|
|
||||||
for (int i = 0; i < 10; ++i) {
|
|
||||||
ggml_backend_sched_graph_compute(sched, graph); // on the first iteration the graph is allocated automatically
|
|
||||||
}
|
|
||||||
|
|
||||||
// if there are graph inputs:
|
|
||||||
graph = build_graph(sched); // get a new graph that is not allocated (the metadata for the old graph is freed once ggml_free is called)
|
|
||||||
ggml_backend_sched_reset(sched); // clear the allocation of the previous graph
|
|
||||||
ggml_backend_sched_alloc_graph(sched, graph); // explicitly allocate the new graph but do not execute it
|
|
||||||
ggml_backend_tensor_set(input_tensor, ...); // copy data to the newly allocated graph tensors
|
|
||||||
ggml_backend_sched_graph_compute(sched, graph); // execute the graph
|
|
||||||
|
|
||||||
// as an alternative to the above it is also possible to assign the inputs to a dedicated context and
|
|
||||||
// allocate them statically via ggml_backend_alloc_ctx_tensors
|
|
||||||
}
|
|
||||||
*/
|
|
||||||
|
|
||||||
typedef struct ggml_backend_sched * ggml_backend_sched_t;
|
|
||||||
|
|
||||||
// Evaluation callback for each node in the graph (set with ggml_backend_sched_set_eval_callback)
|
|
||||||
// when ask == true, the scheduler wants to know if the user wants to observe this node
|
|
||||||
// this allows the scheduler to batch nodes together in order to evaluate them in a single call
|
|
||||||
//
|
|
||||||
// when ask == false, the scheduler is passing the node tensor to the user for observation
|
|
||||||
// if the user returns false, the scheduler will cancel the graph compute
|
|
||||||
//
|
|
||||||
typedef bool (*ggml_backend_sched_eval_callback)(struct ggml_tensor * t, bool ask, void * user_data);
|
|
||||||
|
|
||||||
// Initialize a backend scheduler, backends with low index are given priority over backends with high index
|
|
||||||
GGML_API ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size, bool parallel, bool op_offload);
|
|
||||||
GGML_API void ggml_backend_sched_free(ggml_backend_sched_t sched);
|
|
||||||
|
|
||||||
// Initialize backend buffers from a measure graph
|
|
||||||
GGML_API void ggml_backend_sched_reserve_size(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph, size_t * sizes);
|
|
||||||
GGML_API bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph); // returns success
|
|
||||||
|
|
||||||
GGML_API int ggml_backend_sched_get_n_backends(ggml_backend_sched_t sched);
|
|
||||||
GGML_API ggml_backend_t ggml_backend_sched_get_backend(ggml_backend_sched_t sched, int i);
|
|
||||||
|
|
||||||
// Get the number of splits of the last graph
|
|
||||||
GGML_API int ggml_backend_sched_get_n_splits(ggml_backend_sched_t sched);
|
|
||||||
GGML_API int ggml_backend_sched_get_n_copies(ggml_backend_sched_t sched);
|
|
||||||
|
|
||||||
GGML_API ggml_backend_buffer_type_t ggml_backend_sched_get_buffer_type(ggml_backend_sched_t sched, ggml_backend_t backend);
|
|
||||||
GGML_API size_t ggml_backend_sched_get_buffer_size(ggml_backend_sched_t sched, ggml_backend_t backend);
|
|
||||||
|
|
||||||
GGML_API void ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend);
|
|
||||||
GGML_API ggml_backend_t ggml_backend_sched_get_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node);
|
|
||||||
|
|
||||||
// Split graph without allocating it
|
|
||||||
GGML_API void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph);
|
|
||||||
|
|
||||||
// Allocate and compute graph on the backend scheduler
|
|
||||||
GGML_API bool ggml_backend_sched_alloc_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph); // returns success
|
|
||||||
GGML_API enum ggml_status ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, struct ggml_cgraph * graph);
|
|
||||||
GGML_API enum ggml_status ggml_backend_sched_graph_compute_async(ggml_backend_sched_t sched, struct ggml_cgraph * graph);
|
|
||||||
GGML_API void ggml_backend_sched_synchronize(ggml_backend_sched_t sched);
|
|
||||||
|
|
||||||
// Reset all assignments and allocators - must be called before changing the node backends or allocating a new graph.
|
|
||||||
// This in effect deallocates all tensors that were previously allocated and leaves them with dangling pointers.
|
|
||||||
// The correct way to use this API is to discard the deallocated tensors and create new ones.
|
|
||||||
GGML_API void ggml_backend_sched_reset(ggml_backend_sched_t sched);
|
|
||||||
|
|
||||||
// Set a callback to be called for each resulting node during graph compute
|
|
||||||
GGML_API void ggml_backend_sched_set_eval_callback(ggml_backend_sched_t sched, ggml_backend_sched_eval_callback callback, void * user_data);
|
|
||||||
|
|
||||||
//
|
|
||||||
// Utils
|
|
||||||
//
|
|
||||||
|
|
||||||
struct ggml_backend_graph_copy {
|
|
||||||
ggml_backend_buffer_t buffer;
|
|
||||||
struct ggml_context * ctx_allocated;
|
|
||||||
struct ggml_context * ctx_unallocated;
|
|
||||||
struct ggml_cgraph * graph;
|
|
||||||
};
|
|
||||||
|
|
||||||
// Copy a graph to a different backend
|
|
||||||
GGML_API struct ggml_backend_graph_copy ggml_backend_graph_copy(ggml_backend_t backend, struct ggml_cgraph * graph);
|
|
||||||
GGML_API void ggml_backend_graph_copy_free(struct ggml_backend_graph_copy copy);
|
|
||||||
|
|
||||||
typedef bool (*ggml_backend_eval_callback)(int node_index, struct ggml_tensor * t1, struct ggml_tensor * t2, void * user_data);
|
|
||||||
|
|
||||||
// Compare the output of two backends
|
|
||||||
GGML_API bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data, struct ggml_tensor const * const * test_nodes, size_t num_test_nodes);
|
|
||||||
|
|
||||||
// Tensor initialization
|
|
||||||
GGML_API enum ggml_status ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr);
|
|
||||||
GGML_API enum ggml_status ggml_backend_view_init(struct ggml_tensor * tensor);
|
|
||||||
|
|
||||||
// CPU buffer types are always available
|
|
||||||
GGML_API ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(void * ptr, size_t size);
|
|
||||||
GGML_API ggml_backend_buffer_type_t ggml_backend_cpu_buffer_type(void);
|
|
||||||
|
|
||||||
#ifdef __cplusplus
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
@@ -1,25 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
#include "ggml.h"
|
|
||||||
#include "ggml-backend.h"
|
|
||||||
|
|
||||||
|
|
||||||
#ifdef __cplusplus
|
|
||||||
extern "C" {
|
|
||||||
#endif
|
|
||||||
|
|
||||||
// backend API
|
|
||||||
GGML_BACKEND_API ggml_backend_t ggml_backend_blas_init(void);
|
|
||||||
|
|
||||||
GGML_BACKEND_API bool ggml_backend_is_blas(ggml_backend_t backend);
|
|
||||||
|
|
||||||
// number of threads used for conversion to float
|
|
||||||
// for openblas and blis, this will also set the number of threads used for blas operations
|
|
||||||
GGML_BACKEND_API void ggml_backend_blas_set_n_threads(ggml_backend_t backend_blas, int n_threads);
|
|
||||||
|
|
||||||
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_blas_reg(void);
|
|
||||||
|
|
||||||
|
|
||||||
#ifdef __cplusplus
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
@@ -1,123 +0,0 @@
|
|||||||
/*
|
|
||||||
* Copyright (c) 2023-2024 The ggml authors
|
|
||||||
*
|
|
||||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
||||||
* of this software and associated documentation files (the "Software"), to
|
|
||||||
* deal in the Software without restriction, including without limitation the
|
|
||||||
* rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
|
|
||||||
* sell copies of the Software, and to permit persons to whom the Software is
|
|
||||||
* furnished to do so, subject to the following conditions:
|
|
||||||
*
|
|
||||||
* The above copyright notice and this permission notice shall be included in
|
|
||||||
* all copies or substantial portions of the Software.
|
|
||||||
*
|
|
||||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
||||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
||||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
||||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
||||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
|
||||||
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
|
|
||||||
* IN THE SOFTWARE.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include "ggml-backend.h"
|
|
||||||
#include "ggml.h"
|
|
||||||
|
|
||||||
#ifdef __cplusplus
|
|
||||||
extern "C" {
|
|
||||||
#endif
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Maximum number of CANN devices supported.
|
|
||||||
*/
|
|
||||||
#define GGML_CANN_MAX_DEVICES 16
|
|
||||||
|
|
||||||
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_cann_reg(void);
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Initializes the CANN backend for a specified device.
|
|
||||||
*
|
|
||||||
* This function initializes the CANN backend for the given device.
|
|
||||||
* It verifies the device index, allocates a context, and creates a backend
|
|
||||||
* instance.
|
|
||||||
*
|
|
||||||
* @param device The index of the device to initialize.
|
|
||||||
* @return A pointer to the initialized backend instance, or nullptr on failure.
|
|
||||||
*/
|
|
||||||
GGML_BACKEND_API ggml_backend_t ggml_backend_cann_init(int32_t device);
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Checks if a given backend is a CANN backend.
|
|
||||||
*
|
|
||||||
* This function verifies if the provided backend is a CANN backend by comparing
|
|
||||||
* its GUID with the CANN backend's GUID.
|
|
||||||
*
|
|
||||||
* @param backend The backend instance to check.
|
|
||||||
* @return True if the backend is a CANN backend, false otherwise.
|
|
||||||
*/
|
|
||||||
GGML_BACKEND_API bool ggml_backend_is_cann(ggml_backend_t backend);
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Retrieves the CANN buffer type for a specified device.
|
|
||||||
*
|
|
||||||
* This function initializes and returns the buffer type interface associated
|
|
||||||
* with the given device. It ensures thread-safe access using a mutex.
|
|
||||||
*
|
|
||||||
* @param device The device index for which to retrieve the buffer type.
|
|
||||||
* @return A pointer to the buffer type interface for the specified device, or
|
|
||||||
* nullptr if the device index is out of range.
|
|
||||||
*/
|
|
||||||
GGML_BACKEND_API ggml_backend_buffer_type_t
|
|
||||||
ggml_backend_cann_buffer_type(int32_t device);
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Retrieves the number of CANN devices available.
|
|
||||||
*
|
|
||||||
* This function returns the number of CANN devices available based on
|
|
||||||
* information obtained from `ggml_cann_info()`.
|
|
||||||
*
|
|
||||||
* @return The number of CANN devices available.
|
|
||||||
*/
|
|
||||||
GGML_BACKEND_API int32_t ggml_backend_cann_get_device_count(void);
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief pinned host buffer for use with the CPU backend for faster copies between CPU and NPU.
|
|
||||||
*
|
|
||||||
* @return A pointer to the host buffer type interface.
|
|
||||||
*/
|
|
||||||
GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_cann_host_buffer_type(void);
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Retrieves the description of a specific CANN device.
|
|
||||||
*
|
|
||||||
* This function sets the specified device, retrieves the SoC name,
|
|
||||||
* and writes it into the provided description buffer.
|
|
||||||
*
|
|
||||||
* @param device The device index to retrieve the description for.
|
|
||||||
* @param description Pointer to a buffer where the description will be written.
|
|
||||||
* @param description_size Size of the description buffer.
|
|
||||||
*/
|
|
||||||
GGML_BACKEND_API void ggml_backend_cann_get_device_description(
|
|
||||||
int32_t device, char* description, size_t description_size);
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Retrieves the memory information of a specific CANN device.
|
|
||||||
*
|
|
||||||
* This function sets the specified device, retrieves the free and total
|
|
||||||
* memory information of the specified type (ACL_HBM_MEM), and stores them
|
|
||||||
* in the provided pointers.
|
|
||||||
*
|
|
||||||
* @param device The device index to retrieve memory information for.
|
|
||||||
* @param free Pointer to a variable where the free memory size will be stored.
|
|
||||||
* @param total Pointer to a variable where the total memory size will be
|
|
||||||
* stored.
|
|
||||||
*/
|
|
||||||
GGML_BACKEND_API void ggml_backend_cann_get_device_memory(int32_t device,
|
|
||||||
size_t* free,
|
|
||||||
size_t* total);
|
|
||||||
|
|
||||||
#ifdef __cplusplus
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
@@ -1,39 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
#ifndef __cplusplus
|
|
||||||
#error "This header is for C++ only"
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#include "ggml.h"
|
|
||||||
#include "ggml-alloc.h"
|
|
||||||
#include "ggml-backend.h"
|
|
||||||
#include "gguf.h"
|
|
||||||
#include <memory>
|
|
||||||
|
|
||||||
// Smart pointers for ggml types
|
|
||||||
|
|
||||||
// ggml
|
|
||||||
|
|
||||||
struct ggml_context_deleter { void operator()(ggml_context * ctx) { ggml_free(ctx); } };
|
|
||||||
struct gguf_context_deleter { void operator()(gguf_context * ctx) { gguf_free(ctx); } };
|
|
||||||
|
|
||||||
typedef std::unique_ptr<ggml_context, ggml_context_deleter> ggml_context_ptr;
|
|
||||||
typedef std::unique_ptr<gguf_context, gguf_context_deleter> gguf_context_ptr;
|
|
||||||
|
|
||||||
// ggml-alloc
|
|
||||||
|
|
||||||
struct ggml_gallocr_deleter { void operator()(ggml_gallocr_t galloc) { ggml_gallocr_free(galloc); } };
|
|
||||||
|
|
||||||
typedef std::unique_ptr<ggml_gallocr, ggml_gallocr_deleter> ggml_gallocr_ptr;
|
|
||||||
|
|
||||||
// ggml-backend
|
|
||||||
|
|
||||||
struct ggml_backend_deleter { void operator()(ggml_backend_t backend) { ggml_backend_free(backend); } };
|
|
||||||
struct ggml_backend_buffer_deleter { void operator()(ggml_backend_buffer_t buffer) { ggml_backend_buffer_free(buffer); } };
|
|
||||||
struct ggml_backend_event_deleter { void operator()(ggml_backend_event_t event) { ggml_backend_event_free(event); } };
|
|
||||||
struct ggml_backend_sched_deleter { void operator()(ggml_backend_sched_t sched) { ggml_backend_sched_free(sched); } };
|
|
||||||
|
|
||||||
typedef std::unique_ptr<ggml_backend, ggml_backend_deleter> ggml_backend_ptr;
|
|
||||||
typedef std::unique_ptr<ggml_backend_buffer, ggml_backend_buffer_deleter> ggml_backend_buffer_ptr;
|
|
||||||
typedef std::unique_ptr<ggml_backend_event, ggml_backend_event_deleter> ggml_backend_event_ptr;
|
|
||||||
typedef std::unique_ptr<ggml_backend_sched, ggml_backend_sched_deleter> ggml_backend_sched_ptr;
|
|
||||||
@@ -1,146 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
#include "ggml.h"
|
|
||||||
#include "ggml-backend.h"
|
|
||||||
|
|
||||||
#ifdef __cplusplus
|
|
||||||
extern "C" {
|
|
||||||
#endif
|
|
||||||
|
|
||||||
// the compute plan that needs to be prepared for ggml_graph_compute()
|
|
||||||
// since https://github.com/ggml-org/ggml/issues/287
|
|
||||||
struct ggml_cplan {
|
|
||||||
size_t work_size; // size of work buffer, calculated by `ggml_graph_plan()`
|
|
||||||
uint8_t * work_data; // work buffer, to be allocated by caller before calling to `ggml_graph_compute()`
|
|
||||||
|
|
||||||
int n_threads;
|
|
||||||
struct ggml_threadpool * threadpool;
|
|
||||||
|
|
||||||
// abort ggml_graph_compute when true
|
|
||||||
ggml_abort_callback abort_callback;
|
|
||||||
void * abort_callback_data;
|
|
||||||
};
|
|
||||||
|
|
||||||
// numa strategies
|
|
||||||
enum ggml_numa_strategy {
|
|
||||||
GGML_NUMA_STRATEGY_DISABLED = 0,
|
|
||||||
GGML_NUMA_STRATEGY_DISTRIBUTE = 1,
|
|
||||||
GGML_NUMA_STRATEGY_ISOLATE = 2,
|
|
||||||
GGML_NUMA_STRATEGY_NUMACTL = 3,
|
|
||||||
GGML_NUMA_STRATEGY_MIRROR = 4,
|
|
||||||
GGML_NUMA_STRATEGY_COUNT
|
|
||||||
};
|
|
||||||
|
|
||||||
GGML_BACKEND_API void ggml_numa_init(enum ggml_numa_strategy numa); // call once for better performance on NUMA systems
|
|
||||||
GGML_BACKEND_API bool ggml_is_numa(void); // true if init detected that system has >1 NUMA node
|
|
||||||
|
|
||||||
GGML_BACKEND_API struct ggml_tensor * ggml_new_i32(struct ggml_context * ctx, int32_t value);
|
|
||||||
GGML_BACKEND_API struct ggml_tensor * ggml_new_f32(struct ggml_context * ctx, float value);
|
|
||||||
|
|
||||||
GGML_BACKEND_API struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value);
|
|
||||||
GGML_BACKEND_API struct ggml_tensor * ggml_set_f32 (struct ggml_tensor * tensor, float value);
|
|
||||||
|
|
||||||
GGML_BACKEND_API int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i);
|
|
||||||
GGML_BACKEND_API void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value);
|
|
||||||
|
|
||||||
GGML_BACKEND_API int32_t ggml_get_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3);
|
|
||||||
GGML_BACKEND_API void ggml_set_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3, int32_t value);
|
|
||||||
|
|
||||||
GGML_BACKEND_API float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i);
|
|
||||||
GGML_BACKEND_API void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value);
|
|
||||||
|
|
||||||
GGML_BACKEND_API float ggml_get_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3);
|
|
||||||
GGML_BACKEND_API void ggml_set_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3, float value);
|
|
||||||
|
|
||||||
GGML_BACKEND_API struct ggml_threadpool * ggml_threadpool_new (struct ggml_threadpool_params * params);
|
|
||||||
GGML_BACKEND_API void ggml_threadpool_free (struct ggml_threadpool * threadpool);
|
|
||||||
GGML_BACKEND_API int ggml_threadpool_get_n_threads (struct ggml_threadpool * threadpool);
|
|
||||||
GGML_BACKEND_API void ggml_threadpool_pause (struct ggml_threadpool * threadpool);
|
|
||||||
GGML_BACKEND_API void ggml_threadpool_resume (struct ggml_threadpool * threadpool);
|
|
||||||
|
|
||||||
// ggml_graph_plan() has to be called before ggml_graph_compute()
|
|
||||||
// when plan.work_size > 0, caller must allocate memory for plan.work_data
|
|
||||||
GGML_BACKEND_API struct ggml_cplan ggml_graph_plan(
|
|
||||||
const struct ggml_cgraph * cgraph,
|
|
||||||
int n_threads, /* = GGML_DEFAULT_N_THREADS */
|
|
||||||
struct ggml_threadpool * threadpool /* = NULL */ );
|
|
||||||
GGML_BACKEND_API enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan);
|
|
||||||
|
|
||||||
// same as ggml_graph_compute() but the work data is allocated as a part of the context
|
|
||||||
// note: the drawback of this API is that you must have ensured that the context has enough memory for the work data
|
|
||||||
GGML_BACKEND_API enum ggml_status ggml_graph_compute_with_ctx(struct ggml_context * ctx, struct ggml_cgraph * cgraph, int n_threads);
|
|
||||||
|
|
||||||
//
|
|
||||||
// system info
|
|
||||||
//
|
|
||||||
|
|
||||||
// x86
|
|
||||||
GGML_BACKEND_API int ggml_cpu_has_sse3 (void);
|
|
||||||
GGML_BACKEND_API int ggml_cpu_has_ssse3 (void);
|
|
||||||
GGML_BACKEND_API int ggml_cpu_has_avx (void);
|
|
||||||
GGML_BACKEND_API int ggml_cpu_has_avx_vnni (void);
|
|
||||||
GGML_BACKEND_API int ggml_cpu_has_avx2 (void);
|
|
||||||
GGML_BACKEND_API int ggml_cpu_has_bmi2 (void);
|
|
||||||
GGML_BACKEND_API int ggml_cpu_has_f16c (void);
|
|
||||||
GGML_BACKEND_API int ggml_cpu_has_fma (void);
|
|
||||||
GGML_BACKEND_API int ggml_cpu_has_avx512 (void);
|
|
||||||
GGML_BACKEND_API int ggml_cpu_has_avx512_vbmi(void);
|
|
||||||
GGML_BACKEND_API int ggml_cpu_has_avx512_vnni(void);
|
|
||||||
GGML_BACKEND_API int ggml_cpu_has_avx512_bf16(void);
|
|
||||||
GGML_BACKEND_API int ggml_cpu_has_amx_int8 (void);
|
|
||||||
// ARM
|
|
||||||
GGML_BACKEND_API int ggml_cpu_has_neon (void);
|
|
||||||
GGML_BACKEND_API int ggml_cpu_has_arm_fma (void);
|
|
||||||
GGML_BACKEND_API int ggml_cpu_has_fp16_va (void);
|
|
||||||
GGML_BACKEND_API int ggml_cpu_has_dotprod (void);
|
|
||||||
GGML_BACKEND_API int ggml_cpu_has_matmul_int8(void);
|
|
||||||
GGML_BACKEND_API int ggml_cpu_has_sve (void);
|
|
||||||
GGML_BACKEND_API int ggml_cpu_get_sve_cnt (void); // sve vector length in bytes
|
|
||||||
GGML_BACKEND_API int ggml_cpu_has_sme (void);
|
|
||||||
// other
|
|
||||||
GGML_BACKEND_API int ggml_cpu_has_riscv_v (void);
|
|
||||||
GGML_BACKEND_API int ggml_cpu_get_rvv_vlen (void); // risc-v vector length in bytes
|
|
||||||
GGML_BACKEND_API int ggml_cpu_has_vsx (void);
|
|
||||||
GGML_BACKEND_API int ggml_cpu_has_vxe (void);
|
|
||||||
GGML_BACKEND_API int ggml_cpu_has_wasm_simd (void);
|
|
||||||
GGML_BACKEND_API int ggml_cpu_has_llamafile (void);
|
|
||||||
|
|
||||||
// Internal types and functions exposed for tests and benchmarks
|
|
||||||
|
|
||||||
typedef void (*ggml_vec_dot_t) (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT x, size_t bx,
|
|
||||||
const void * GGML_RESTRICT y, size_t by, int nrc);
|
|
||||||
|
|
||||||
struct ggml_type_traits_cpu {
|
|
||||||
ggml_from_float_t from_float;
|
|
||||||
ggml_vec_dot_t vec_dot;
|
|
||||||
enum ggml_type vec_dot_type;
|
|
||||||
int64_t nrows; // number of rows to process simultaneously
|
|
||||||
};
|
|
||||||
|
|
||||||
GGML_BACKEND_API const struct ggml_type_traits_cpu * ggml_get_type_traits_cpu(enum ggml_type type);
|
|
||||||
|
|
||||||
GGML_BACKEND_API void ggml_cpu_init(void);
|
|
||||||
|
|
||||||
//
|
|
||||||
// CPU backend
|
|
||||||
//
|
|
||||||
|
|
||||||
GGML_BACKEND_API ggml_backend_t ggml_backend_cpu_init(void);
|
|
||||||
|
|
||||||
GGML_BACKEND_API bool ggml_backend_is_cpu (ggml_backend_t backend);
|
|
||||||
GGML_BACKEND_API void ggml_backend_cpu_set_n_threads (ggml_backend_t backend_cpu, int n_threads);
|
|
||||||
GGML_BACKEND_API void ggml_backend_cpu_set_threadpool (ggml_backend_t backend_cpu, ggml_threadpool_t threadpool);
|
|
||||||
GGML_BACKEND_API void ggml_backend_cpu_set_abort_callback(ggml_backend_t backend_cpu, ggml_abort_callback abort_callback, void * abort_callback_data);
|
|
||||||
|
|
||||||
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_cpu_reg(void);
|
|
||||||
|
|
||||||
GGML_BACKEND_API void ggml_cpu_fp32_to_fp32(const float *, float *, int64_t);
|
|
||||||
GGML_BACKEND_API void ggml_cpu_fp32_to_i32 (const float *, int32_t *, int64_t);
|
|
||||||
GGML_BACKEND_API void ggml_cpu_fp32_to_fp16(const float *, ggml_fp16_t *, int64_t);
|
|
||||||
GGML_BACKEND_API void ggml_cpu_fp16_to_fp32(const ggml_fp16_t *, float *, int64_t);
|
|
||||||
GGML_BACKEND_API void ggml_cpu_fp32_to_bf16(const float *, ggml_bf16_t *, int64_t);
|
|
||||||
GGML_BACKEND_API void ggml_cpu_bf16_to_fp32(const ggml_bf16_t *, float *, int64_t);
|
|
||||||
|
|
||||||
#ifdef __cplusplus
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
@@ -1,47 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
#include "ggml.h"
|
|
||||||
#include "ggml-backend.h"
|
|
||||||
|
|
||||||
#ifdef __cplusplus
|
|
||||||
extern "C" {
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#ifdef GGML_USE_HIP
|
|
||||||
#define GGML_CUDA_NAME "ROCm"
|
|
||||||
#define GGML_CUBLAS_NAME "hipBLAS"
|
|
||||||
#elif defined(GGML_USE_MUSA)
|
|
||||||
#define GGML_CUDA_NAME "MUSA"
|
|
||||||
#define GGML_CUBLAS_NAME "muBLAS"
|
|
||||||
#else
|
|
||||||
#define GGML_CUDA_NAME "CUDA"
|
|
||||||
#define GGML_CUBLAS_NAME "cuBLAS"
|
|
||||||
#endif
|
|
||||||
#define GGML_CUDA_MAX_DEVICES 16
|
|
||||||
|
|
||||||
// backend API
|
|
||||||
GGML_BACKEND_API ggml_backend_t ggml_backend_cuda_init(int device);
|
|
||||||
|
|
||||||
GGML_BACKEND_API bool ggml_backend_is_cuda(ggml_backend_t backend);
|
|
||||||
|
|
||||||
// device buffer
|
|
||||||
GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device);
|
|
||||||
|
|
||||||
// split tensor buffer that splits matrices by rows across multiple devices
|
|
||||||
GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_cuda_split_buffer_type(int main_device, const float * tensor_split);
|
|
||||||
|
|
||||||
// pinned host buffer for use with the CPU backend for faster copies between CPU and GPU
|
|
||||||
GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_cuda_host_buffer_type(void);
|
|
||||||
|
|
||||||
GGML_BACKEND_API int ggml_backend_cuda_get_device_count(void);
|
|
||||||
GGML_BACKEND_API void ggml_backend_cuda_get_device_description(int device, char * description, size_t description_size);
|
|
||||||
GGML_BACKEND_API void ggml_backend_cuda_get_device_memory(int device, size_t * free, size_t * total);
|
|
||||||
|
|
||||||
GGML_BACKEND_API bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size);
|
|
||||||
GGML_BACKEND_API void ggml_backend_cuda_unregister_host_buffer(void * buffer);
|
|
||||||
|
|
||||||
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_cuda_reg(void);
|
|
||||||
|
|
||||||
#ifdef __cplusplus
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
@@ -1,19 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
#include "ggml.h"
|
|
||||||
#include "ggml-backend.h"
|
|
||||||
|
|
||||||
#ifdef __cplusplus
|
|
||||||
extern "C" {
|
|
||||||
#endif
|
|
||||||
|
|
||||||
// backend API
|
|
||||||
GGML_BACKEND_API ggml_backend_t ggml_backend_hexagon_init(void);
|
|
||||||
|
|
||||||
GGML_BACKEND_API bool ggml_backend_is_hexagon(ggml_backend_t backend);
|
|
||||||
|
|
||||||
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_hexagon_reg(void);
|
|
||||||
|
|
||||||
#ifdef __cplusplus
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
@@ -1,61 +0,0 @@
|
|||||||
// Note: this description is outdated
|
|
||||||
//
|
|
||||||
// An interface allowing to compute ggml_cgraph with Metal
|
|
||||||
//
|
|
||||||
// This is a fully functional interface that extends ggml with GPU support for Apple devices.
|
|
||||||
// A similar interface can be created for other GPU backends (e.g. Vulkan, CUDA, etc.)
|
|
||||||
//
|
|
||||||
// How it works?
|
|
||||||
//
|
|
||||||
// As long as your program can create and evaluate a ggml_cgraph on the CPU, you can use this
|
|
||||||
// interface to evaluate the same graph on the GPU. Instead of using ggml_graph_compute(), you
|
|
||||||
// use ggml_metal_graph_compute() (or ggml_vulkan_graph_compute(), etc.)
|
|
||||||
//
|
|
||||||
// You only need to make sure that all memory buffers that you used during the graph creation
|
|
||||||
// are mapped to the device memory with the ggml_metal_add_buffer() function. This mapping is
|
|
||||||
// used during the graph evaluation to determine the arguments of the compute kernels.
|
|
||||||
//
|
|
||||||
// Synchronization between device and host memory (for example for input and output tensors)
|
|
||||||
// is done with the ggml_metal_set_tensor() and ggml_metal_get_tensor() functions.
|
|
||||||
//
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include "ggml.h"
|
|
||||||
#include "ggml-backend.h"
|
|
||||||
|
|
||||||
#include <stddef.h>
|
|
||||||
#include <stdbool.h>
|
|
||||||
|
|
||||||
struct ggml_tensor;
|
|
||||||
struct ggml_cgraph;
|
|
||||||
|
|
||||||
#ifdef __cplusplus
|
|
||||||
extern "C" {
|
|
||||||
#endif
|
|
||||||
|
|
||||||
//
|
|
||||||
// backend API
|
|
||||||
// user-code should use only these functions
|
|
||||||
//
|
|
||||||
|
|
||||||
// TODO: remove in the future
|
|
||||||
GGML_BACKEND_API ggml_backend_t ggml_backend_metal_init(void);
|
|
||||||
|
|
||||||
GGML_BACKEND_API bool ggml_backend_is_metal(ggml_backend_t backend);
|
|
||||||
|
|
||||||
GGML_BACKEND_API void ggml_backend_metal_set_abort_callback(ggml_backend_t backend, ggml_abort_callback abort_callback, void * user_data);
|
|
||||||
|
|
||||||
// helper to check if the device supports a specific family
|
|
||||||
// ideally, the user code should be doing these checks
|
|
||||||
// ref: https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
|
|
||||||
GGML_BACKEND_API bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family);
|
|
||||||
|
|
||||||
// capture all command buffers committed the next time `ggml_backend_graph_compute` is called
|
|
||||||
GGML_BACKEND_API void ggml_backend_metal_capture_next_compute(ggml_backend_t backend);
|
|
||||||
|
|
||||||
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_metal_reg(void);
|
|
||||||
|
|
||||||
#ifdef __cplusplus
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
@@ -1,26 +0,0 @@
|
|||||||
#ifndef GGML_OPENCL_H
|
|
||||||
#define GGML_OPENCL_H
|
|
||||||
|
|
||||||
#include "ggml.h"
|
|
||||||
#include "ggml-backend.h"
|
|
||||||
|
|
||||||
#ifdef __cplusplus
|
|
||||||
extern "C" {
|
|
||||||
#endif
|
|
||||||
|
|
||||||
//
|
|
||||||
// backend API
|
|
||||||
//
|
|
||||||
GGML_BACKEND_API ggml_backend_t ggml_backend_opencl_init(void);
|
|
||||||
GGML_BACKEND_API bool ggml_backend_is_opencl(ggml_backend_t backend);
|
|
||||||
|
|
||||||
GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_opencl_buffer_type(void);
|
|
||||||
GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_opencl_host_buffer_type(void);
|
|
||||||
|
|
||||||
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_opencl_reg(void);
|
|
||||||
|
|
||||||
#ifdef __cplusplus
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#endif // GGML_OPENCL_H
|
|
||||||
@@ -1,256 +0,0 @@
|
|||||||
// This file contains functionality for training models using GGML.
|
|
||||||
// It is not strictly needed vs. just vanilla GGML but it provides a more high-level interface for common needs such as datasets.
|
|
||||||
// At the bottom of this file especially there are relatively high-level functions that are suitable use or adaptation in user code.
|
|
||||||
//
|
|
||||||
// Module maintainer: Johannes Gäßler (@JohannesGaessler, johannesg@5d6.de)
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include "ggml.h"
|
|
||||||
#include "ggml-backend.h"
|
|
||||||
|
|
||||||
#include <stdint.h>
|
|
||||||
|
|
||||||
#ifdef __cplusplus
|
|
||||||
extern "C" {
|
|
||||||
#endif
|
|
||||||
|
|
||||||
struct ggml_opt_dataset;
|
|
||||||
struct ggml_opt_context;
|
|
||||||
struct ggml_opt_result;
|
|
||||||
|
|
||||||
typedef struct ggml_opt_dataset * ggml_opt_dataset_t;
|
|
||||||
typedef struct ggml_opt_context * ggml_opt_context_t;
|
|
||||||
typedef struct ggml_opt_result * ggml_opt_result_t;
|
|
||||||
|
|
||||||
// ====== Loss ======
|
|
||||||
|
|
||||||
// built-in loss types, i.e. the built-in quantities minimized by the optimizer
|
|
||||||
// custom loss types can be defined via mean or sum which simply reduce the outputs for all datapoints to a single value
|
|
||||||
enum ggml_opt_loss_type {
|
|
||||||
GGML_OPT_LOSS_TYPE_MEAN,
|
|
||||||
GGML_OPT_LOSS_TYPE_SUM,
|
|
||||||
GGML_OPT_LOSS_TYPE_CROSS_ENTROPY,
|
|
||||||
GGML_OPT_LOSS_TYPE_MEAN_SQUARED_ERROR,
|
|
||||||
};
|
|
||||||
|
|
||||||
// ====== Dataset ======
|
|
||||||
|
|
||||||
GGML_API ggml_opt_dataset_t ggml_opt_dataset_init(
|
|
||||||
enum ggml_type type_data, // the type for the internal data tensor
|
|
||||||
enum ggml_type type_label, // the type for the internal labels tensor
|
|
||||||
int64_t ne_datapoint, // number of elements per datapoint
|
|
||||||
int64_t ne_label, // number of elements per label
|
|
||||||
int64_t ndata, // total number of datapoints/labels
|
|
||||||
int64_t ndata_shard); // number of datapoints/labels per shard (unit at which the dataset is shuffled/copied)
|
|
||||||
GGML_API void ggml_opt_dataset_free(ggml_opt_dataset_t dataset);
|
|
||||||
|
|
||||||
// get underlying tensors that store the data
|
|
||||||
GGML_API int64_t ggml_opt_dataset_ndata (ggml_opt_dataset_t dataset);
|
|
||||||
GGML_API struct ggml_tensor * ggml_opt_dataset_data (ggml_opt_dataset_t dataset); // shape = [ne_datapoint, ndata]
|
|
||||||
GGML_API struct ggml_tensor * ggml_opt_dataset_labels(ggml_opt_dataset_t dataset); // shape = [nd_label, ndata]
|
|
||||||
|
|
||||||
// shuffle idata first datapoints from dataset with RNG from opt_ctx, shuffle all datapoints if idata is negative
|
|
||||||
GGML_API void ggml_opt_dataset_shuffle(ggml_opt_context_t opt_ctx, ggml_opt_dataset_t dataset, int64_t idata);
|
|
||||||
|
|
||||||
// get batch at position ibatch from dataset and copy the data to data_batch and labels_batch
|
|
||||||
GGML_API void ggml_opt_dataset_get_batch(
|
|
||||||
ggml_opt_dataset_t dataset,
|
|
||||||
struct ggml_tensor * data_batch, // shape = [ne_datapoint, ndata_batch]
|
|
||||||
struct ggml_tensor * labels_batch, // shape = [ne_label, ndata_batch]
|
|
||||||
int64_t ibatch);
|
|
||||||
GGML_API void ggml_opt_dataset_get_batch_host(
|
|
||||||
ggml_opt_dataset_t dataset,
|
|
||||||
void * data_batch,
|
|
||||||
size_t nb_data_batch,
|
|
||||||
void * labels_batch,
|
|
||||||
int64_t ibatch);
|
|
||||||
|
|
||||||
// ====== Model / Context ======
|
|
||||||
|
|
||||||
enum ggml_opt_build_type {
|
|
||||||
GGML_OPT_BUILD_TYPE_FORWARD = 10,
|
|
||||||
GGML_OPT_BUILD_TYPE_GRAD = 20,
|
|
||||||
GGML_OPT_BUILD_TYPE_OPT = 30,
|
|
||||||
};
|
|
||||||
|
|
||||||
enum ggml_opt_optimizer_type {
|
|
||||||
GGML_OPT_OPTIMIZER_TYPE_ADAMW,
|
|
||||||
GGML_OPT_OPTIMIZER_TYPE_SGD,
|
|
||||||
|
|
||||||
GGML_OPT_OPTIMIZER_TYPE_COUNT
|
|
||||||
};
|
|
||||||
|
|
||||||
// parameters that control which optimizer is used and how said optimizer tries to find the minimal loss
|
|
||||||
struct ggml_opt_optimizer_params {
|
|
||||||
struct {
|
|
||||||
float alpha; // learning rate
|
|
||||||
float beta1; // first AdamW momentum
|
|
||||||
float beta2; // second AdamW momentum
|
|
||||||
float eps; // epsilon for numerical stability
|
|
||||||
float wd; // weight decay - 0.0f to disable
|
|
||||||
} adamw;
|
|
||||||
struct {
|
|
||||||
float alpha; // learning rate
|
|
||||||
float wd; // weight decay
|
|
||||||
} sgd;
|
|
||||||
};
|
|
||||||
|
|
||||||
// callback to calculate optimizer parameters prior to a backward pass
|
|
||||||
// userdata can be used to pass arbitrary data
|
|
||||||
typedef struct ggml_opt_optimizer_params (*ggml_opt_get_optimizer_params)(void * userdata);
|
|
||||||
|
|
||||||
// returns the default optimizer params (constant, hard-coded values)
|
|
||||||
// userdata is not used
|
|
||||||
GGML_API struct ggml_opt_optimizer_params ggml_opt_get_default_optimizer_params(void * userdata);
|
|
||||||
|
|
||||||
// casts userdata to ggml_opt_optimizer_params and returns it
|
|
||||||
GGML_API struct ggml_opt_optimizer_params ggml_opt_get_constant_optimizer_params(void * userdata);
|
|
||||||
|
|
||||||
// parameters for initializing a new optimization context
|
|
||||||
struct ggml_opt_params {
|
|
||||||
ggml_backend_sched_t backend_sched; // defines which backends are used to construct the compute graphs
|
|
||||||
|
|
||||||
// by default the forward graph needs to be reconstructed for each eval
|
|
||||||
// if ctx_compute, inputs, and outputs are set the graphs are instead allocated statically
|
|
||||||
struct ggml_context * ctx_compute;
|
|
||||||
struct ggml_tensor * inputs;
|
|
||||||
struct ggml_tensor * outputs;
|
|
||||||
|
|
||||||
enum ggml_opt_loss_type loss_type;
|
|
||||||
enum ggml_opt_build_type build_type;
|
|
||||||
|
|
||||||
int32_t opt_period; // after how many gradient accumulation steps an optimizer step should be done
|
|
||||||
|
|
||||||
ggml_opt_get_optimizer_params get_opt_pars; // callback for calculating optimizer parameters
|
|
||||||
void * get_opt_pars_ud; // userdata for calculating optimizer parameters
|
|
||||||
|
|
||||||
// only GGML_OPT_OPTIMIZER_TYPE_ADAMW needs m, v momenta per parameter tensor
|
|
||||||
enum ggml_opt_optimizer_type optimizer;
|
|
||||||
};
|
|
||||||
|
|
||||||
// get parameters for an optimization context with defaults set where possible
|
|
||||||
// parameters for which no sensible defaults exist are supplied as arguments to this function
|
|
||||||
GGML_API struct ggml_opt_params ggml_opt_default_params(
|
|
||||||
ggml_backend_sched_t backend_sched,
|
|
||||||
enum ggml_opt_loss_type loss_type);
|
|
||||||
|
|
||||||
GGML_API ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params);
|
|
||||||
GGML_API void ggml_opt_free(ggml_opt_context_t opt_ctx);
|
|
||||||
|
|
||||||
// set gradients to zero, initilize loss, and optionally reset the optimizer
|
|
||||||
GGML_API void ggml_opt_reset(ggml_opt_context_t opt_ctx, bool optimizer);
|
|
||||||
|
|
||||||
GGML_API bool ggml_opt_static_graphs(ggml_opt_context_t opt_ctx); // whether the graphs are allocated_statically
|
|
||||||
|
|
||||||
// get underlying tensors that store data
|
|
||||||
// if not using static graphs these pointers become invalid with the next call to ggml_opt_alloc
|
|
||||||
GGML_API struct ggml_tensor * ggml_opt_inputs( ggml_opt_context_t opt_ctx); // forward graph input tensor
|
|
||||||
GGML_API struct ggml_tensor * ggml_opt_outputs( ggml_opt_context_t opt_ctx); // forward graph output tensor
|
|
||||||
GGML_API struct ggml_tensor * ggml_opt_labels( ggml_opt_context_t opt_ctx); // labels to compare outputs against
|
|
||||||
GGML_API struct ggml_tensor * ggml_opt_loss( ggml_opt_context_t opt_ctx); // scalar tensor that contains the loss
|
|
||||||
GGML_API struct ggml_tensor * ggml_opt_pred( ggml_opt_context_t opt_ctx); // predictions made by outputs
|
|
||||||
GGML_API struct ggml_tensor * ggml_opt_ncorrect(ggml_opt_context_t opt_ctx); // number of matching predictions between outputs and labels
|
|
||||||
|
|
||||||
// get the gradient accumulator for a node from the forward graph
|
|
||||||
GGML_API struct ggml_tensor * ggml_opt_grad_acc(ggml_opt_context_t opt_ctx, struct ggml_tensor * node);
|
|
||||||
|
|
||||||
GGML_API enum ggml_opt_optimizer_type ggml_opt_context_optimizer_type(ggml_opt_context_t); //TODO consistent naming scheme
|
|
||||||
|
|
||||||
GGML_API const char * ggml_opt_optimizer_name(enum ggml_opt_optimizer_type);
|
|
||||||
|
|
||||||
// ====== Optimization Result ======
|
|
||||||
|
|
||||||
GGML_API ggml_opt_result_t ggml_opt_result_init(void);
|
|
||||||
GGML_API void ggml_opt_result_free(ggml_opt_result_t result);
|
|
||||||
GGML_API void ggml_opt_result_reset(ggml_opt_result_t result);
|
|
||||||
|
|
||||||
// get data from result, uncertainties are optional and can be ignored by passing NULL
|
|
||||||
GGML_API void ggml_opt_result_ndata( ggml_opt_result_t result, int64_t * ndata); // writes 1 value, number of datapoints
|
|
||||||
GGML_API void ggml_opt_result_loss( ggml_opt_result_t result, double * loss, double * unc); // writes 1 value
|
|
||||||
GGML_API void ggml_opt_result_pred( ggml_opt_result_t result, int32_t * pred); // writes ndata values
|
|
||||||
GGML_API void ggml_opt_result_accuracy(ggml_opt_result_t result, double * accuracy, double * unc); // writes 1 value
|
|
||||||
|
|
||||||
// ====== Computation ======
|
|
||||||
|
|
||||||
// if not using static graphs, this function must be called prior to ggml_opt_alloc
|
|
||||||
GGML_API void ggml_opt_prepare_alloc(
|
|
||||||
ggml_opt_context_t opt_ctx,
|
|
||||||
struct ggml_context * ctx_compute,
|
|
||||||
struct ggml_cgraph * gf,
|
|
||||||
struct ggml_tensor * inputs,
|
|
||||||
struct ggml_tensor * outputs);
|
|
||||||
|
|
||||||
// allocate the next graph for evaluation, either forward or forward + backward
|
|
||||||
// must be called exactly once prior to calling ggml_opt_eval
|
|
||||||
GGML_API void ggml_opt_alloc(ggml_opt_context_t opt_ctx, bool backward);
|
|
||||||
|
|
||||||
// do forward pass, increment result if not NULL, do backward pass if allocated
|
|
||||||
GGML_API void ggml_opt_eval(ggml_opt_context_t opt_ctx, ggml_opt_result_t result);
|
|
||||||
|
|
||||||
// ############################################################################
|
|
||||||
// ## The high-level functions start here. They do not depend on any private ##
|
|
||||||
// ## functions or structs and can be copied to and adapted for user code. ##
|
|
||||||
// ############################################################################
|
|
||||||
|
|
||||||
// ====== Intended Usage ======
|
|
||||||
//
|
|
||||||
// 1. Select the appropriate loss for your problem.
|
|
||||||
// 2. Create a dataset and set the data for the "data" tensor. Also set the "labels" tensor if your loss needs them.
|
|
||||||
// Setting the shard size to 1 will be fine, it's the granularity with which data is shuffled/loaded (bigger values are faster).
|
|
||||||
// 3. Create a GGML graph for your model with no_alloc == true. Use two separate contexts for the tensors.
|
|
||||||
// The first context should contain the model parameters and inputs and be allocated statically in user code.
|
|
||||||
// The second context should contain all other tensors and will be (re)allocated automatically.
|
|
||||||
// Due to this automated allocation the data of the second context is not defined when accessed in user code.
|
|
||||||
// Note that the second dimension of the inputs/outputs are interpreted as the number of datapoints in those tensors.
|
|
||||||
// 4. Call ggml_opt_fit. If you need more control you can use ggml_opt_epoch instead.
|
|
||||||
|
|
||||||
// signature for a callback while evaluating opt_ctx on dataset, called after an evaluation
|
|
||||||
typedef void (*ggml_opt_epoch_callback)(
|
|
||||||
bool train, // true after training evaluation, false after validation evaluation
|
|
||||||
ggml_opt_context_t opt_ctx,
|
|
||||||
ggml_opt_dataset_t dataset,
|
|
||||||
ggml_opt_result_t result, // result associated with the dataset subsection
|
|
||||||
int64_t ibatch, // number of batches that have been evaluated so far
|
|
||||||
int64_t ibatch_max, // total number of batches in this dataset subsection
|
|
||||||
int64_t t_start_us); // time at which the evaluation on the dataset subsection was started
|
|
||||||
|
|
||||||
// do training on front of dataset, do evaluation only on back of dataset
|
|
||||||
GGML_API void ggml_opt_epoch(
|
|
||||||
ggml_opt_context_t opt_ctx,
|
|
||||||
ggml_opt_dataset_t dataset,
|
|
||||||
ggml_opt_result_t result_train, // result to increment during training, ignored if NULL
|
|
||||||
ggml_opt_result_t result_eval, // result to increment during evaluation, ignored if NULL
|
|
||||||
int64_t idata_split, // data index at which to split training and evaluation
|
|
||||||
ggml_opt_epoch_callback callback_train,
|
|
||||||
ggml_opt_epoch_callback callback_eval);
|
|
||||||
|
|
||||||
// callback that prints a progress bar on stderr
|
|
||||||
GGML_API void ggml_opt_epoch_callback_progress_bar(
|
|
||||||
bool train,
|
|
||||||
ggml_opt_context_t opt_ctx,
|
|
||||||
ggml_opt_dataset_t dataset,
|
|
||||||
ggml_opt_result_t result,
|
|
||||||
int64_t ibatch,
|
|
||||||
int64_t ibatch_max,
|
|
||||||
int64_t t_start_us);
|
|
||||||
|
|
||||||
// fit model defined by inputs and outputs to dataset
|
|
||||||
GGML_API void ggml_opt_fit(
|
|
||||||
ggml_backend_sched_t backend_sched, // backend scheduler for constructing the compute graphs
|
|
||||||
struct ggml_context * ctx_compute, // context with temporarily allocated tensors to calculate the outputs
|
|
||||||
struct ggml_tensor * inputs, // input tensor with shape [ne_datapoint, ndata_batch]
|
|
||||||
struct ggml_tensor * outputs, // output tensor, must have shape [ne_label, ndata_batch] if labels are used
|
|
||||||
ggml_opt_dataset_t dataset, // dataset with data and optionally also labels
|
|
||||||
enum ggml_opt_loss_type loss_type, // loss to minimize
|
|
||||||
enum ggml_opt_optimizer_type optimizer, // sgd or adamw
|
|
||||||
ggml_opt_get_optimizer_params get_opt_pars, // callback to get optimizer params, userdata is pointer to epoch (of type int64_t)
|
|
||||||
int64_t nepoch, // how many times the dataset should be iterated over
|
|
||||||
int64_t nbatch_logical, // datapoints optimizer step, must be a multiple of ndata_batch in inputs/outputs
|
|
||||||
float val_split, // fraction of the dataset to use for validation, must be in [0.0f, 1.0f)
|
|
||||||
bool silent); // whether or not info prints to stderr should be suppressed
|
|
||||||
|
|
||||||
|
|
||||||
#ifdef __cplusplus
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
@@ -1,30 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
#include "ggml-backend.h"
|
|
||||||
|
|
||||||
#ifdef __cplusplus
|
|
||||||
extern "C" {
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#define RPC_PROTO_MAJOR_VERSION 3
|
|
||||||
#define RPC_PROTO_MINOR_VERSION 6
|
|
||||||
#define RPC_PROTO_PATCH_VERSION 0
|
|
||||||
#define GGML_RPC_MAX_SERVERS 16
|
|
||||||
|
|
||||||
// backend API
|
|
||||||
GGML_BACKEND_API ggml_backend_t ggml_backend_rpc_init(const char * endpoint, uint32_t device);
|
|
||||||
GGML_BACKEND_API bool ggml_backend_is_rpc(ggml_backend_t backend);
|
|
||||||
|
|
||||||
GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint, uint32_t device);
|
|
||||||
|
|
||||||
GGML_BACKEND_API void ggml_backend_rpc_get_device_memory(const char * endpoint, uint32_t device, size_t * free, size_t * total);
|
|
||||||
|
|
||||||
GGML_BACKEND_API void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir,
|
|
||||||
size_t n_threads, size_t n_devices, ggml_backend_dev_t * devices);
|
|
||||||
|
|
||||||
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_rpc_reg(void);
|
|
||||||
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_rpc_add_server(const char * endpoint);
|
|
||||||
|
|
||||||
#ifdef __cplusplus
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
@@ -1,49 +0,0 @@
|
|||||||
//
|
|
||||||
// MIT license
|
|
||||||
// Copyright (C) 2024 Intel Corporation
|
|
||||||
// SPDX-License-Identifier: MIT
|
|
||||||
//
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include "ggml.h"
|
|
||||||
#include "ggml-backend.h"
|
|
||||||
|
|
||||||
#define GGML_SYCL_NAME "SYCL"
|
|
||||||
#define GGML_SYCL_MAX_DEVICES 48
|
|
||||||
|
|
||||||
#ifdef __cplusplus
|
|
||||||
extern "C" {
|
|
||||||
#endif
|
|
||||||
|
|
||||||
// backend API
|
|
||||||
GGML_BACKEND_API ggml_backend_t ggml_backend_sycl_init(int device);
|
|
||||||
|
|
||||||
GGML_BACKEND_API bool ggml_backend_is_sycl(ggml_backend_t backend);
|
|
||||||
|
|
||||||
// devide buffer
|
|
||||||
GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(int device);
|
|
||||||
|
|
||||||
// split tensor buffer that splits matrices by rows across multiple devices
|
|
||||||
GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_sycl_split_buffer_type(const float * tensor_split);
|
|
||||||
|
|
||||||
// pinned host buffer for use with the CPU backend for faster copies between CPU and GPU
|
|
||||||
GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_sycl_host_buffer_type(void);
|
|
||||||
|
|
||||||
GGML_BACKEND_API void ggml_backend_sycl_print_sycl_devices(void);
|
|
||||||
GGML_BACKEND_API void ggml_backend_sycl_get_gpu_list(int *id_list, int max_len);
|
|
||||||
GGML_BACKEND_API void ggml_backend_sycl_get_device_description(int device,
|
|
||||||
char *description,
|
|
||||||
size_t description_size);
|
|
||||||
GGML_BACKEND_API int ggml_backend_sycl_get_device_count();
|
|
||||||
GGML_BACKEND_API void ggml_backend_sycl_get_device_memory(int device, size_t *free, size_t *total);
|
|
||||||
|
|
||||||
// SYCL doesn't support registering host memory, keep here for reference
|
|
||||||
// GGML_BACKEND_API bool ggml_backend_sycl_register_host_buffer(void * buffer, size_t size);
|
|
||||||
// GGML_BACKEND_API void ggml_backend_sycl_unregister_host_buffer(void * buffer);
|
|
||||||
|
|
||||||
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_sycl_reg(void);
|
|
||||||
|
|
||||||
#ifdef __cplusplus
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
@@ -1,29 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
#include "ggml.h"
|
|
||||||
#include "ggml-backend.h"
|
|
||||||
|
|
||||||
#ifdef __cplusplus
|
|
||||||
extern "C" {
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#define GGML_VK_NAME "Vulkan"
|
|
||||||
#define GGML_VK_MAX_DEVICES 16
|
|
||||||
|
|
||||||
// backend API
|
|
||||||
GGML_BACKEND_API ggml_backend_t ggml_backend_vk_init(size_t dev_num);
|
|
||||||
|
|
||||||
GGML_BACKEND_API bool ggml_backend_is_vk(ggml_backend_t backend);
|
|
||||||
GGML_BACKEND_API int ggml_backend_vk_get_device_count(void);
|
|
||||||
GGML_BACKEND_API void ggml_backend_vk_get_device_description(int device, char * description, size_t description_size);
|
|
||||||
GGML_BACKEND_API void ggml_backend_vk_get_device_memory(int device, size_t * free, size_t * total);
|
|
||||||
|
|
||||||
GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_vk_buffer_type(size_t dev_num);
|
|
||||||
// pinned host buffer for use with the CPU backend for faster copies between CPU and GPU
|
|
||||||
GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_vk_host_buffer_type(void);
|
|
||||||
|
|
||||||
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_vk_reg(void);
|
|
||||||
|
|
||||||
#ifdef __cplusplus
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
@@ -1,19 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
#include "ggml.h"
|
|
||||||
#include "ggml-backend.h"
|
|
||||||
|
|
||||||
#ifdef __cplusplus
|
|
||||||
extern "C" {
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#define GGML_WEBGPU_NAME "WebGPU"
|
|
||||||
|
|
||||||
// Needed for examples in ggml
|
|
||||||
GGML_BACKEND_API ggml_backend_t ggml_backend_webgpu_init(void);
|
|
||||||
|
|
||||||
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_webgpu_reg(void);
|
|
||||||
|
|
||||||
#ifdef __cplusplus
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
@@ -1,17 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
#include "ggml.h"
|
|
||||||
#include "ggml-backend.h"
|
|
||||||
|
|
||||||
#ifdef __cplusplus
|
|
||||||
extern "C" {
|
|
||||||
#endif
|
|
||||||
|
|
||||||
// device buffer
|
|
||||||
GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_zdnn_buffer_type(void);
|
|
||||||
|
|
||||||
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_zdnn_reg(void);
|
|
||||||
|
|
||||||
#ifdef __cplusplus
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
@@ -1,22 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
#include "ggml-backend.h"
|
|
||||||
#include "ggml.h"
|
|
||||||
|
|
||||||
#ifdef __cplusplus
|
|
||||||
extern "C" {
|
|
||||||
#endif
|
|
||||||
|
|
||||||
// backend API
|
|
||||||
GGML_BACKEND_API ggml_backend_t ggml_backend_zendnn_init(void);
|
|
||||||
|
|
||||||
GGML_BACKEND_API bool ggml_backend_is_zendnn(ggml_backend_t backend);
|
|
||||||
|
|
||||||
// number of threads used for zendnn operations
|
|
||||||
GGML_BACKEND_API void ggml_backend_zendnn_set_n_threads(ggml_backend_t backend_zendnn, int n_threads);
|
|
||||||
|
|
||||||
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_zendnn_reg(void);
|
|
||||||
|
|
||||||
#ifdef __cplusplus
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -1,202 +0,0 @@
|
|||||||
// This file contains functionality related to "GGUF" files, the binary file format used by ggml.
|
|
||||||
// GGUF files have the following structure:
|
|
||||||
//
|
|
||||||
// 1. File magic "GGUF" (4 bytes).
|
|
||||||
// 2. File version (uint32_t).
|
|
||||||
// 3. Number of ggml tensors in file (int64_t).
|
|
||||||
// 4. Number of key-value-pairs in file (int64_t).
|
|
||||||
// 5. For each KV pair:
|
|
||||||
// 1. The key (string).
|
|
||||||
// 2. The value type (gguf_type).
|
|
||||||
// 3a. If the value type is GGUF_TYPE_ARRAY:
|
|
||||||
// 1. The type of the array (gguf_type).
|
|
||||||
// 2. The number of elements in the array (uint64_t).
|
|
||||||
// 3. The binary representation of each element in the array.
|
|
||||||
// 3b. Otherwise:
|
|
||||||
// 1. The binary representation of the value.
|
|
||||||
// 6. For each ggml tensor:
|
|
||||||
// 1. The tensor name (string).
|
|
||||||
// 2. The number of dimensions of the tensor (uint32_t).
|
|
||||||
// 3. For each dimension:
|
|
||||||
// 1. The size of the tensor in the dimension (int64_t).
|
|
||||||
// 4. The tensor data type (ggml_type).
|
|
||||||
// 5. The tensor data offset in the tensor data binary blob (uint64_t).
|
|
||||||
// 7. The tensor data binary blob (optional, aligned).
|
|
||||||
//
|
|
||||||
// Strings are serialized as the string length (uint64_t) followed by the C string without the null terminator.
|
|
||||||
// All enums are stored as int32_t.
|
|
||||||
// All bool values are stored as int8_t.
|
|
||||||
// If the special key "general.alignment" (uint32_t) is defined it is used for alignment,
|
|
||||||
// otherwise GGUF_DEFAULT_ALIGNMENT is used.
|
|
||||||
//
|
|
||||||
// Module maintainer: Johannes Gäßler (@JohannesGaessler, johannesg@5d6.de)
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include "ggml.h"
|
|
||||||
|
|
||||||
#include <stdbool.h>
|
|
||||||
#include <stdint.h>
|
|
||||||
|
|
||||||
#define GGUF_MAGIC "GGUF"
|
|
||||||
#define GGUF_VERSION 3
|
|
||||||
|
|
||||||
#define GGUF_KEY_GENERAL_ALIGNMENT "general.alignment"
|
|
||||||
|
|
||||||
#define GGUF_DEFAULT_ALIGNMENT 32
|
|
||||||
|
|
||||||
#ifdef __cplusplus
|
|
||||||
extern "C" {
|
|
||||||
#endif
|
|
||||||
|
|
||||||
// types that can be stored as GGUF KV data
|
|
||||||
enum gguf_type {
|
|
||||||
GGUF_TYPE_UINT8 = 0,
|
|
||||||
GGUF_TYPE_INT8 = 1,
|
|
||||||
GGUF_TYPE_UINT16 = 2,
|
|
||||||
GGUF_TYPE_INT16 = 3,
|
|
||||||
GGUF_TYPE_UINT32 = 4,
|
|
||||||
GGUF_TYPE_INT32 = 5,
|
|
||||||
GGUF_TYPE_FLOAT32 = 6,
|
|
||||||
GGUF_TYPE_BOOL = 7,
|
|
||||||
GGUF_TYPE_STRING = 8,
|
|
||||||
GGUF_TYPE_ARRAY = 9,
|
|
||||||
GGUF_TYPE_UINT64 = 10,
|
|
||||||
GGUF_TYPE_INT64 = 11,
|
|
||||||
GGUF_TYPE_FLOAT64 = 12,
|
|
||||||
GGUF_TYPE_COUNT, // marks the end of the enum
|
|
||||||
};
|
|
||||||
|
|
||||||
struct gguf_context;
|
|
||||||
|
|
||||||
struct gguf_init_params {
|
|
||||||
bool no_alloc;
|
|
||||||
|
|
||||||
// if not NULL, create a ggml_context and allocate the tensor data in it
|
|
||||||
struct ggml_context ** ctx;
|
|
||||||
};
|
|
||||||
|
|
||||||
GGML_API struct gguf_context * gguf_init_empty(void);
|
|
||||||
GGML_API struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_params params);
|
|
||||||
//GGML_API struct gguf_context * gguf_init_from_buffer(..);
|
|
||||||
|
|
||||||
GGML_API void gguf_free(struct gguf_context * ctx);
|
|
||||||
|
|
||||||
GGML_API const char * gguf_type_name(enum gguf_type type);
|
|
||||||
|
|
||||||
GGML_API uint32_t gguf_get_version (const struct gguf_context * ctx);
|
|
||||||
GGML_API size_t gguf_get_alignment (const struct gguf_context * ctx);
|
|
||||||
GGML_API size_t gguf_get_data_offset(const struct gguf_context * ctx);
|
|
||||||
|
|
||||||
GGML_API int64_t gguf_get_n_kv(const struct gguf_context * ctx);
|
|
||||||
GGML_API int64_t gguf_find_key(const struct gguf_context * ctx, const char * key); // returns -1 if key is not found
|
|
||||||
GGML_API const char * gguf_get_key (const struct gguf_context * ctx, int64_t key_id);
|
|
||||||
|
|
||||||
GGML_API enum gguf_type gguf_get_kv_type (const struct gguf_context * ctx, int64_t key_id);
|
|
||||||
GGML_API enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int64_t key_id);
|
|
||||||
|
|
||||||
// will abort if the wrong type is used for the key
|
|
||||||
GGML_API uint8_t gguf_get_val_u8 (const struct gguf_context * ctx, int64_t key_id);
|
|
||||||
GGML_API int8_t gguf_get_val_i8 (const struct gguf_context * ctx, int64_t key_id);
|
|
||||||
GGML_API uint16_t gguf_get_val_u16 (const struct gguf_context * ctx, int64_t key_id);
|
|
||||||
GGML_API int16_t gguf_get_val_i16 (const struct gguf_context * ctx, int64_t key_id);
|
|
||||||
GGML_API uint32_t gguf_get_val_u32 (const struct gguf_context * ctx, int64_t key_id);
|
|
||||||
GGML_API int32_t gguf_get_val_i32 (const struct gguf_context * ctx, int64_t key_id);
|
|
||||||
GGML_API float gguf_get_val_f32 (const struct gguf_context * ctx, int64_t key_id);
|
|
||||||
GGML_API uint64_t gguf_get_val_u64 (const struct gguf_context * ctx, int64_t key_id);
|
|
||||||
GGML_API int64_t gguf_get_val_i64 (const struct gguf_context * ctx, int64_t key_id);
|
|
||||||
GGML_API double gguf_get_val_f64 (const struct gguf_context * ctx, int64_t key_id);
|
|
||||||
GGML_API bool gguf_get_val_bool(const struct gguf_context * ctx, int64_t key_id);
|
|
||||||
GGML_API const char * gguf_get_val_str (const struct gguf_context * ctx, int64_t key_id);
|
|
||||||
GGML_API const void * gguf_get_val_data(const struct gguf_context * ctx, int64_t key_id);
|
|
||||||
GGML_API size_t gguf_get_arr_n (const struct gguf_context * ctx, int64_t key_id);
|
|
||||||
|
|
||||||
// get raw pointer to the first element of the array with the given key_id
|
|
||||||
// for bool arrays, note that they are always stored as int8 on all platforms (usually this makes no difference)
|
|
||||||
GGML_API const void * gguf_get_arr_data(const struct gguf_context * ctx, int64_t key_id);
|
|
||||||
|
|
||||||
// get ith C string from array with given key_id
|
|
||||||
GGML_API const char * gguf_get_arr_str (const struct gguf_context * ctx, int64_t key_id, size_t i);
|
|
||||||
|
|
||||||
GGML_API int64_t gguf_get_n_tensors (const struct gguf_context * ctx);
|
|
||||||
GGML_API int64_t gguf_find_tensor (const struct gguf_context * ctx, const char * name); // returns -1 if the tensor is not found
|
|
||||||
GGML_API size_t gguf_get_tensor_offset(const struct gguf_context * ctx, int64_t tensor_id);
|
|
||||||
GGML_API const char * gguf_get_tensor_name (const struct gguf_context * ctx, int64_t tensor_id);
|
|
||||||
GGML_API enum ggml_type gguf_get_tensor_type (const struct gguf_context * ctx, int64_t tensor_id);
|
|
||||||
GGML_API size_t gguf_get_tensor_size (const struct gguf_context * ctx, int64_t tensor_id);
|
|
||||||
|
|
||||||
// removes key if it exists, returns id that the key had prior to removal (-1 if it didn't exist)
|
|
||||||
GGML_API int64_t gguf_remove_key(struct gguf_context * ctx, const char * key);
|
|
||||||
|
|
||||||
// overrides an existing KV pair or adds a new one, the new KV pair is always at the back
|
|
||||||
GGML_API void gguf_set_val_u8 (struct gguf_context * ctx, const char * key, uint8_t val);
|
|
||||||
GGML_API void gguf_set_val_i8 (struct gguf_context * ctx, const char * key, int8_t val);
|
|
||||||
GGML_API void gguf_set_val_u16 (struct gguf_context * ctx, const char * key, uint16_t val);
|
|
||||||
GGML_API void gguf_set_val_i16 (struct gguf_context * ctx, const char * key, int16_t val);
|
|
||||||
GGML_API void gguf_set_val_u32 (struct gguf_context * ctx, const char * key, uint32_t val);
|
|
||||||
GGML_API void gguf_set_val_i32 (struct gguf_context * ctx, const char * key, int32_t val);
|
|
||||||
GGML_API void gguf_set_val_f32 (struct gguf_context * ctx, const char * key, float val);
|
|
||||||
GGML_API void gguf_set_val_u64 (struct gguf_context * ctx, const char * key, uint64_t val);
|
|
||||||
GGML_API void gguf_set_val_i64 (struct gguf_context * ctx, const char * key, int64_t val);
|
|
||||||
GGML_API void gguf_set_val_f64 (struct gguf_context * ctx, const char * key, double val);
|
|
||||||
GGML_API void gguf_set_val_bool(struct gguf_context * ctx, const char * key, bool val);
|
|
||||||
GGML_API void gguf_set_val_str (struct gguf_context * ctx, const char * key, const char * val);
|
|
||||||
|
|
||||||
// creates a new array with n elements of the given type and copies the corresponding number of bytes from data
|
|
||||||
GGML_API void gguf_set_arr_data(struct gguf_context * ctx, const char * key, enum gguf_type type, const void * data, size_t n);
|
|
||||||
|
|
||||||
// creates a new array with n strings and copies the corresponding strings from data
|
|
||||||
GGML_API void gguf_set_arr_str (struct gguf_context * ctx, const char * key, const char ** data, size_t n);
|
|
||||||
|
|
||||||
// set or add KV pairs from another context
|
|
||||||
GGML_API void gguf_set_kv(struct gguf_context * ctx, const struct gguf_context * src);
|
|
||||||
|
|
||||||
// add tensor to GGUF context, tensor name must be unique
|
|
||||||
GGML_API void gguf_add_tensor(struct gguf_context * ctx, const struct ggml_tensor * tensor);
|
|
||||||
|
|
||||||
// after changing a tensor's type, the offsets of all tensors with higher indices are immediately recalculated
|
|
||||||
// in such a way that the tensor data remains as one contiguous block (except for padding)
|
|
||||||
GGML_API void gguf_set_tensor_type(struct gguf_context * ctx, const char * name, enum ggml_type type);
|
|
||||||
|
|
||||||
// assumes that at least gguf_get_tensor_size bytes can be read from data
|
|
||||||
GGML_API void gguf_set_tensor_data(struct gguf_context * ctx, const char * name, const void * data);
|
|
||||||
|
|
||||||
// writing gguf files can be done in 3 ways:
|
|
||||||
//
|
|
||||||
// - write the entire gguf_context to a binary file in a single pass:
|
|
||||||
//
|
|
||||||
// gguf_write_to_file(ctx, fname, /*only_meta =*/ false);
|
|
||||||
//
|
|
||||||
// - write only the meta data to a file, then re-open the file and append the tensor data:
|
|
||||||
//
|
|
||||||
// gguf_write_to_file(ctx, fname, /*only_meta =*/ true);
|
|
||||||
// FILE * f = fopen(fname, "ab");
|
|
||||||
// fwrite(f, ...); // write tensor data
|
|
||||||
// fclose(f);
|
|
||||||
//
|
|
||||||
// - first prepare a file with a placeholder for the meta data, write the tensor data, then write the meta data:
|
|
||||||
//
|
|
||||||
// FILE * f = fopen(fname, "wb");
|
|
||||||
// const size_t size_meta = gguf_get_meta_size(ctx);
|
|
||||||
// fseek(f, size_meta, SEEK_SET);
|
|
||||||
// fwrite(f, ...); // write tensor data
|
|
||||||
// void * data = malloc(size_meta);
|
|
||||||
// gguf_get_meta_data(ctx, data);
|
|
||||||
// rewind(f);
|
|
||||||
// fwrite(data, 1, data, f);
|
|
||||||
// free(data);
|
|
||||||
// fclose(f);
|
|
||||||
//
|
|
||||||
|
|
||||||
// write the entire context to a binary file
|
|
||||||
GGML_API bool gguf_write_to_file(const struct gguf_context * ctx, const char * fname, bool only_meta);
|
|
||||||
|
|
||||||
// get the size in bytes of the meta data (header, kv pairs, tensor info) including padding
|
|
||||||
GGML_API size_t gguf_get_meta_size(const struct gguf_context * ctx);
|
|
||||||
|
|
||||||
// writes the meta data to pointer "data"
|
|
||||||
GGML_API void gguf_get_meta_data(const struct gguf_context * ctx, void * data);
|
|
||||||
|
|
||||||
#ifdef __cplusplus
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
@@ -1,490 +0,0 @@
|
|||||||
include(CheckCXXCompilerFlag)
include("../cmake/common.cmake")

add_compile_definitions(GGML_SCHED_MAX_COPIES=${GGML_SCHED_MAX_COPIES})

# enable libstdc++ assertions for debug builds
if (CMAKE_SYSTEM_NAME MATCHES "Linux")
    add_compile_definitions($<$<CONFIG:Debug>:_GLIBCXX_ASSERTIONS>)
endif()

# Sanitizers are opt-in via the GGML_SANITIZE_* options; MSVC does not accept
# the GCC/Clang-style -fsanitize flags, so skip it entirely.
# Note: the original used link_libraries(-fsanitize=...), which smuggles a
# linker FLAG through the library list; add_link_options() is the dedicated
# command for linker flags and has the same directory-scoped effect.
if (NOT MSVC)
    if (GGML_SANITIZE_THREAD)
        add_compile_options(-fsanitize=thread)
        add_link_options   (-fsanitize=thread)
    endif()

    if (GGML_SANITIZE_ADDRESS)
        add_compile_options(-fsanitize=address -fno-omit-frame-pointer)
        add_link_options   (-fsanitize=address)
    endif()

    if (GGML_SANITIZE_UNDEFINED)
        add_compile_options(-fsanitize=undefined)
        add_link_options   (-fsanitize=undefined)
    endif()
endif()
# Optionally promote all warnings to errors. For GCC/Clang the flag is
# appended to the C_FLAGS/CXX_FLAGS lists consumed by the warning setup
# below; for MSVC it is applied directly.
if (GGML_FATAL_WARNINGS)
    if (CMAKE_CXX_COMPILER_ID MATCHES "GNU|Clang")
        list(APPEND C_FLAGS   -Werror)
        list(APPEND CXX_FLAGS -Werror)
    elseif (CMAKE_CXX_COMPILER_ID STREQUAL "MSVC")
        add_compile_options(/WX)
    endif()
endif()
# Broad warning set for non-MSVC compilers, applied per-language through
# COMPILE_LANGUAGE generator expressions so C-only flags never reach C++ TUs.
if (GGML_ALL_WARNINGS)
    if (MSVC)
        # todo : msvc
        set(C_FLAGS   "")
        set(CXX_FLAGS "")
    else()
        list(APPEND WARNING_FLAGS -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function)

        list(APPEND C_FLAGS
             -Wshadow -Wstrict-prototypes -Wpointer-arith -Wmissing-prototypes
             -Werror=implicit-int -Werror=implicit-function-declaration)
        list(APPEND CXX_FLAGS
             -Wmissing-declarations -Wmissing-noreturn)

        list(APPEND C_FLAGS   ${WARNING_FLAGS})
        list(APPEND CXX_FLAGS ${WARNING_FLAGS})

        # Defined in ../cmake/common.cmake; fills GF_C_FLAGS / GF_CXX_FLAGS
        # with compiler- and version-specific extras.
        ggml_get_flags(${CMAKE_CXX_COMPILER_ID} ${CMAKE_CXX_COMPILER_VERSION})

        add_compile_options("$<$<COMPILE_LANGUAGE:C>:${C_FLAGS};${GF_C_FLAGS}>"
                            "$<$<COMPILE_LANGUAGE:CXX>:${CXX_FLAGS};${GF_CXX_FLAGS}>")
    endif()
endif()
# Enable link-time optimization when the toolchain supports it; warn (do not
# fail) otherwise so the build still proceeds.
if (GGML_LTO)
    include(CheckIPOSupported)
    check_ipo_supported(RESULT result OUTPUT output)
    if (NOT result)
        message(WARNING "IPO is not supported: ${output}")
    else()
        set(CMAKE_INTERPROCEDURAL_OPTIMIZATION TRUE)
    endif()
endif()
# Use ccache (or sccache) as a compiler launcher, but only when the user has
# not already configured a launcher of their own.
if (GGML_CCACHE AND NOT CMAKE_C_COMPILER_LAUNCHER AND NOT CMAKE_CXX_COMPILER_LAUNCHER)
    find_program(GGML_CCACHE_FOUND  ccache)
    find_program(GGML_SCCACHE_FOUND sccache)

    if (NOT GGML_CCACHE_FOUND AND NOT GGML_SCCACHE_FOUND)
        message(STATUS "Warning: ccache not found - consider installing it for faster compilation or disable this warning with GGML_CCACHE=OFF")
    else()
        if (GGML_CCACHE_FOUND)
            set(GGML_CCACHE_VARIANT ccache)
        else()
            set(GGML_CCACHE_VARIANT sccache)
        endif()

        # TODO: should not be set globally
        if (GGML_SYCL AND GGML_CCACHE_FOUND AND WIN32)
            set_property(GLOBAL PROPERTY RULE_LAUNCH_COMPILE "ccache compiler_type=icl")
        else()
            set_property(GLOBAL PROPERTY RULE_LAUNCH_COMPILE "${GGML_CCACHE_VARIANT}")
        endif()

        # NOTE(review): set(ENV{...}) only changes the configure-time process
        # environment — verify ccache actually observes this at build time.
        set(ENV{CCACHE_SLOPPINESS} time_macros)

        message(STATUS "${GGML_CCACHE_VARIANT} found, compilation results will be cached. Disable with GGML_CCACHE=OFF.")
    endif()
endif()
# this version of Apple ld64 is buggy
# Probe the linker version by asking the C compiler to run the link step
# verbosely; only stderr is captured, stdout is discarded.
execute_process(
    COMMAND ${CMAKE_C_COMPILER} ${CMAKE_EXE_LINKER_FLAGS} -Wl,-v
    ERROR_VARIABLE output
    OUTPUT_QUIET
)

if (output MATCHES "dyld-1015\.7")
    add_compile_definitions(HAVE_BUGGY_APPLE_LINKER)
endif()
# architecture specific
# TODO: probably these flags need to be tweaked on some architectures
# feel free to update the Makefile for your architecture and send a pull request or issue
message(STATUS "CMAKE_SYSTEM_PROCESSOR: ${CMAKE_SYSTEM_PROCESSOR}")

if (MSVC)
    string(TOLOWER "${CMAKE_GENERATOR_PLATFORM}" CMAKE_GENERATOR_PLATFORM_LWR)
    message(STATUS "CMAKE_GENERATOR_PLATFORM: ${CMAKE_GENERATOR_PLATFORM}")
else()
    set(CMAKE_GENERATOR_PLATFORM_LWR "")
endif()

# Sets GGML_SYSTEM_ARCH (helper defined in ../cmake/common.cmake).
ggml_get_system_arch()
message(STATUS "GGML_SYSTEM_ARCH: ${GGML_SYSTEM_ARCH}")
if (NOT MSVC)
    if (GGML_STATIC)
        # Prefer static archives when resolving find_library() on ELF systems.
        if (UNIX AND NOT APPLE)
            set(CMAKE_FIND_LIBRARY_SUFFIXES ".a;.so")
        endif()

        add_link_options(-static)

        if (MINGW)
            add_link_options(-static-libgcc -static-libstdc++)
        endif()
    endif()

    # gprof profiling instrumentation
    if (GGML_GPROF)
        add_compile_options(-pg)
    endif()
endif()
#
# POSIX conformance
#

# clock_gettime came in POSIX.1b (1993)
# CLOCK_MONOTONIC came in POSIX.1-2001 / SUSv3 as optional
# posix_memalign came in POSIX.1-2001 / SUSv3
# M_PI is an XSI extension since POSIX.1-2001 / SUSv3, came in XPG1 (1985)

# Somehow in OpenBSD whenever POSIX conformance is specified
# some string functions rely on locale_t availability,
# which was introduced in POSIX.1-2008, forcing us to go higher
if (CMAKE_SYSTEM_NAME MATCHES "OpenBSD")
    add_compile_definitions(_XOPEN_SOURCE=700)
elseif (CMAKE_SYSTEM_NAME MATCHES "AIX")
    # Don't define _XOPEN_SOURCE. We need _ALL_SOURCE, which is the default,
    # in order to define _SC_PHYS_PAGES.
else()
    add_compile_definitions(_XOPEN_SOURCE=600)
endif()

# Data types, macros and functions related to controlling CPU affinity and
# some memory allocation are available on Linux through GNU extensions in libc
if (CMAKE_SYSTEM_NAME MATCHES "Linux|Android")
    add_compile_definitions(_GNU_SOURCE)
endif()

# RLIMIT_MEMLOCK came in BSD, is not specified in POSIX.1,
# and on macOS its availability depends on enabling Darwin extensions
# similarly on DragonFly, enabling BSD extensions is necessary
if (CMAKE_SYSTEM_NAME MATCHES "Darwin|iOS|tvOS|DragonFly")
    add_compile_definitions(_DARWIN_C_SOURCE)
endif()

# alloca is a non-standard interface that is not visible on BSDs when
# POSIX conformance is specified, but not all of them provide a clean way
# to enable it in such cases
if (CMAKE_SYSTEM_NAME MATCHES "FreeBSD")
    add_compile_definitions(__BSD_VISIBLE)
endif()
if (CMAKE_SYSTEM_NAME MATCHES "NetBSD")
    add_compile_definitions(_NETBSD_SOURCE)
endif()
if (CMAKE_SYSTEM_NAME MATCHES "OpenBSD")
    add_compile_definitions(_BSD_SOURCE)
endif()

# Silence MSVC's deprecation warnings for standard C functions.
if (WIN32)
    add_compile_definitions(_CRT_SECURE_NO_WARNINGS)
endif()
# ggml

# Dynamically loaded backends are built as shared modules, which requires a
# shared-library build of ggml itself.
if (GGML_BACKEND_DL AND NOT BUILD_SHARED_LIBS)
    message(FATAL_ERROR "GGML_BACKEND_DL requires BUILD_SHARED_LIBS")
endif()

# Core library: public headers plus the backend-independent implementation.
add_library(ggml-base
            ../include/ggml.h
            ../include/ggml-alloc.h
            ../include/ggml-backend.h
            ../include/ggml-cpp.h
            ../include/ggml-opt.h
            ../include/gguf.h
            ggml.c
            ggml.cpp
            ggml-alloc.c
            ggml-backend.cpp
            ggml-opt.cpp
            ggml-threading.cpp
            ggml-threading.h
            ggml-quants.c
            ggml-quants.h
            gguf.cpp)

set_target_properties(ggml-base PROPERTIES
    VERSION   ${GGML_VERSION}
    SOVERSION ${GGML_VERSION_MAJOR})
target_include_directories(ggml-base PRIVATE .)

# Both definitions are PUBLIC: consumers of ggml-base compile against headers
# whose contents depend on them.
if (GGML_BACKEND_DL)
    target_compile_definitions(ggml-base PUBLIC GGML_BACKEND_DL)
endif()
if (GGML_SCHED_NO_REALLOC)
    target_compile_definitions(ggml-base PUBLIC GGML_SCHED_NO_REALLOC)
endif()
# The public "ggml" target; backend registration lives here.
add_library(ggml
            ggml-backend-reg.cpp)
add_library(ggml::ggml ALIAS ggml)

set_target_properties(ggml PROPERTIES
    VERSION   ${GGML_VERSION}
    SOVERSION ${GGML_VERSION_MAJOR})

# GGML_BACKEND_DIR overrides where dynamically loaded backends are searched
# for at runtime; it only makes sense together with GGML_BACKEND_DL.
if (GGML_BACKEND_DIR)
    if (NOT GGML_BACKEND_DL)
        message(FATAL_ERROR "GGML_BACKEND_DIR requires GGML_BACKEND_DL")
    endif()
    target_compile_definitions(ggml PUBLIC GGML_BACKEND_DIR="${GGML_BACKEND_DIR}")
endif()

target_link_libraries(ggml PUBLIC ggml-base)

# dlopen()/dlsym() support for loading backends at runtime.
# CMAKE_DL_LIBS is the portable spelling of the hardcoded "dl" it replaces
# (it expands to the platform's dl library, or to nothing where unneeded).
if (CMAKE_SYSTEM_NAME MATCHES "Linux")
    target_link_libraries(ggml PRIVATE ${CMAKE_DL_LIBS})
endif()
# ggml_add_backend_library(<backend> [sources...])
#
# Creates the library target for one backend. With GGML_BACKEND_DL the backend
# becomes a dynamically loaded MODULE placed in the runtime output directory;
# otherwise it is a regular library linked into ggml. The backend name is also
# recorded in the GGML_AVAILABLE_BACKENDS cache list consumed by the generated
# CMake package files.
function(ggml_add_backend_library backend)
    if (GGML_BACKEND_DL)
        add_library(${backend} MODULE ${ARGN})
        # write the shared library to the output directory
        set_target_properties(${backend} PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${CMAKE_RUNTIME_OUTPUT_DIRECTORY})
        target_compile_definitions(${backend} PRIVATE GGML_BACKEND_DL)
        # MODULE libraries are not linked into ggml; add_dependencies only
        # guarantees they are built whenever ggml is.
        add_dependencies(ggml ${backend})
        if (GGML_BACKEND_DIR)
            install(TARGETS ${backend} LIBRARY DESTINATION ${GGML_BACKEND_DIR})
        else()
            install(TARGETS ${backend} LIBRARY DESTINATION ${CMAKE_INSTALL_BINDIR})
        endif()
    else()
        add_library(${backend} ${ARGN})
        target_link_libraries(ggml PUBLIC ${backend})
        install(TARGETS ${backend} LIBRARY)
    endif()

    target_link_libraries(${backend} PRIVATE ggml-base)
    target_include_directories(${backend} PRIVATE ..)

    # Fixed: was `if (${BUILD_SHARED_LIBS})`, which dereferences the variable
    # before if() evaluates it — an error when unset and a double-dereference
    # footgun otherwise. Test the variable name directly.
    if (BUILD_SHARED_LIBS)
        target_compile_definitions(${backend} PRIVATE GGML_BACKEND_BUILD)
        target_compile_definitions(${backend} PUBLIC  GGML_BACKEND_SHARED)
    endif()

    # Set versioning properties for all backend libraries
    # Building a MODULE library with a version is not supported on macOS (https://gitlab.kitware.com/cmake/cmake/-/issues/20782)
    if (NOT (APPLE AND GGML_BACKEND_DL))
        set_target_properties(${backend} PROPERTIES
            VERSION   ${GGML_VERSION}
            SOVERSION ${GGML_VERSION_MAJOR})
    endif()

    # Append to the backend list for the cmake package, avoiding duplicates.
    if (NOT GGML_AVAILABLE_BACKENDS)
        set(GGML_AVAILABLE_BACKENDS "${backend}"
            CACHE INTERNAL "List of backends for cmake package")
    else()
        list(FIND GGML_AVAILABLE_BACKENDS "${backend}" has_backend)
        if (has_backend EQUAL -1)
            set(GGML_AVAILABLE_BACKENDS "${GGML_AVAILABLE_BACKENDS};${backend}"
                CACHE INTERNAL "List of backends for cmake package")
        endif()
    endif()
endfunction()
# ggml_add_backend(<Name>)
#
# Includes the subdirectory for backend <Name> when the GGML_<NAME> option is
# enabled. For statically linked backends (no GGML_BACKEND_DL) it also defines
# GGML_USE_<NAME> on the ggml target so consumers can detect the backend.
function(ggml_add_backend backend)
    string(TOUPPER "GGML_${backend}" backend_id)
    if (NOT ${backend_id})
        return()
    endif()

    string(TOLOWER "ggml-${backend}" backend_target)
    add_subdirectory(${backend_target})
    message(STATUS "Including ${backend} backend")

    if (NOT GGML_BACKEND_DL)
        string(TOUPPER "GGML_USE_${backend}" backend_use)
        target_compile_definitions(ggml PUBLIC ${backend_use})
    endif()
endfunction()
# ggml_add_cpu_backend_variant(<tag> [features...])
#
# Configures per-architecture feature toggles for one CPU backend variant and
# then builds it via ggml_add_cpu_backend_variant_impl(). Variables set here
# are visible to the called implementation because CMake function scopes nest.
function(ggml_add_cpu_backend_variant tag_name)
    set(GGML_CPU_TAG_NAME ${tag_name})
    # other: OPENMP LLAMAFILE CPU_HBM
    if (GGML_SYSTEM_ARCH STREQUAL "x86")
        # Reset every known x86 feature, then enable only the requested ones.
        foreach (feat NATIVE
                      SSE42
                      AVX AVX2 BMI2 AVX_VNNI FMA F16C
                      AVX512 AVX512_VBMI AVX512_VNNI AVX512_BF16
                      AMX_TILE AMX_INT8 AMX_BF16)
            set(GGML_${feat} OFF)
        endforeach()
        foreach (feat ${ARGN})
            set(GGML_${feat} ON)
        endforeach()
    elseif (GGML_SYSTEM_ARCH MATCHES "^(ARM|PowerPC)$")
        # ARM and PowerPC share the same behavior: enable requested features.
        foreach (feat ${ARGN})
            set(GGML_INTERNAL_${feat} ON)
        endforeach()
    elseif (GGML_SYSTEM_ARCH STREQUAL "s390x")
        foreach (feat VXE2 NNPA)
            set(GGML_INTERNAL_${feat} OFF)
        endforeach()
        foreach (feat ${ARGN})
            set(GGML_INTERNAL_${feat} ON)
        endforeach()
    elseif (GGML_SYSTEM_ARCH STREQUAL "riscv64")
        foreach (feat RVV)
            set(GGML_INTERNAL_${feat} OFF)
        endforeach()
        foreach (feat ${ARGN})
            set(GGML_INTERNAL_${feat} ON)
        endforeach()
    endif()

    ggml_add_cpu_backend_variant_impl(${tag_name})
endfunction()
ggml_add_backend(CPU)

# With GGML_CPU_ALL_VARIANTS, build one dynamically loadable CPU backend per
# supported microarchitecture feature set; the best one is picked at runtime.
if (GGML_CPU_ALL_VARIANTS)
    if (NOT GGML_BACKEND_DL)
        message(FATAL_ERROR "GGML_CPU_ALL_VARIANTS requires GGML_BACKEND_DL")
    elseif (GGML_CPU_ARM_ARCH)
        message(FATAL_ERROR "Cannot use both GGML_CPU_ARM_ARCH and GGML_CPU_ALL_VARIANTS")
    endif()
    if (GGML_SYSTEM_ARCH STREQUAL "x86")
        ggml_add_cpu_backend_variant(x64)
        ggml_add_cpu_backend_variant(sse42       SSE42)
        ggml_add_cpu_backend_variant(sandybridge SSE42 AVX)
        if (NOT MSVC)
            # __FMA__ and __F16C__ are not defined in MSVC, however they are implied with AVX2/AVX512
            ggml_add_cpu_backend_variant(ivybridge  SSE42 AVX F16C)
            ggml_add_cpu_backend_variant(piledriver SSE42 AVX F16C FMA)
        endif()
        ggml_add_cpu_backend_variant(haswell     SSE42 AVX F16C FMA AVX2 BMI2)
        ggml_add_cpu_backend_variant(skylakex    SSE42 AVX F16C FMA AVX2 BMI2 AVX512)
        ggml_add_cpu_backend_variant(cannonlake  SSE42 AVX F16C FMA AVX2 BMI2 AVX512 AVX512_VBMI)
        ggml_add_cpu_backend_variant(cascadelake SSE42 AVX F16C FMA AVX2 BMI2 AVX512 AVX512_VNNI)
        ggml_add_cpu_backend_variant(icelake     SSE42 AVX F16C FMA AVX2 BMI2 AVX512 AVX512_VBMI AVX512_VNNI)
        if (NOT MSVC)
            # MSVC 2022 doesn't support BF16 intrinsics without `/arch:AVX10.1` ?!
            # https://learn.microsoft.com/en-us/cpp/intrinsics/x64-amd64-intrinsics-list?view=msvc-170
            # https://learn.microsoft.com/en-us/cpp/build/reference/arch-x64?view=msvc-170
            ggml_add_cpu_backend_variant(cooperlake SSE42 AVX F16C FMA AVX2 BMI2 AVX512 AVX512_VNNI AVX512_BF16)
            ggml_add_cpu_backend_variant(zen4       SSE42 AVX F16C FMA AVX2 BMI2 AVX512 AVX512_VBMI AVX512_VNNI AVX512_BF16)
        endif()
        ggml_add_cpu_backend_variant(alderlake   SSE42 AVX F16C FMA AVX2 BMI2 AVX_VNNI)
        if (NOT MSVC)
            # MSVC doesn't support AMX
            ggml_add_cpu_backend_variant(sapphirerapids SSE42 AVX F16C FMA AVX2 BMI2 AVX512 AVX512_VBMI AVX512_VNNI AVX512_BF16 AMX_TILE AMX_INT8)
        endif()
    elseif (GGML_SYSTEM_ARCH STREQUAL "ARM")
        if (CMAKE_SYSTEM_NAME MATCHES "Linux")
            # Many of these features are optional so we build versions with popular
            # combinations and name the backends based on the version they were
            # first released with
            ggml_add_cpu_backend_variant(armv8.0_1)
            ggml_add_cpu_backend_variant(armv8.2_1 DOTPROD)
            ggml_add_cpu_backend_variant(armv8.2_2 DOTPROD FP16_VECTOR_ARITHMETIC)
            ggml_add_cpu_backend_variant(armv8.2_3 DOTPROD FP16_VECTOR_ARITHMETIC SVE)
            ggml_add_cpu_backend_variant(armv8.6_1 DOTPROD FP16_VECTOR_ARITHMETIC SVE MATMUL_INT8)
            ggml_add_cpu_backend_variant(armv8.6_2 DOTPROD FP16_VECTOR_ARITHMETIC SVE MATMUL_INT8 SVE2)
            ggml_add_cpu_backend_variant(armv9.2_1 DOTPROD FP16_VECTOR_ARITHMETIC SVE MATMUL_INT8 SME)
            ggml_add_cpu_backend_variant(armv9.2_2 DOTPROD FP16_VECTOR_ARITHMETIC SVE MATMUL_INT8 SVE2 SME)
        elseif (CMAKE_SYSTEM_NAME MATCHES "Android")
            # Android-specific backends with SoC-compatible feature sets
            ggml_add_cpu_backend_variant(android_armv8.0_1)
            ggml_add_cpu_backend_variant(android_armv8.2_1 DOTPROD)
            ggml_add_cpu_backend_variant(android_armv8.2_2 DOTPROD FP16_VECTOR_ARITHMETIC)
            ggml_add_cpu_backend_variant(android_armv8.6_1 DOTPROD FP16_VECTOR_ARITHMETIC MATMUL_INT8)
            ggml_add_cpu_backend_variant(android_armv9.0_1 DOTPROD MATMUL_INT8 FP16_VECTOR_ARITHMETIC SVE2)
            ggml_add_cpu_backend_variant(android_armv9.2_1 DOTPROD MATMUL_INT8 FP16_VECTOR_ARITHMETIC SVE SME)
            ggml_add_cpu_backend_variant(android_armv9.2_2 DOTPROD MATMUL_INT8 FP16_VECTOR_ARITHMETIC SVE SVE2 SME)
        elseif (APPLE)
            ggml_add_cpu_backend_variant(apple_m1    DOTPROD)
            ggml_add_cpu_backend_variant(apple_m2_m3 DOTPROD MATMUL_INT8)
            ggml_add_cpu_backend_variant(apple_m4    DOTPROD MATMUL_INT8 NOSVE SME)
        else()
            message(FATAL_ERROR "Unsupported ARM target OS: ${CMAKE_SYSTEM_NAME}")
        endif()
    elseif (GGML_SYSTEM_ARCH STREQUAL "PowerPC")
        if (CMAKE_SYSTEM_NAME MATCHES "Linux")
            ggml_add_cpu_backend_variant(power0)
            ggml_add_cpu_backend_variant(power7_1 POWER7)
            ggml_add_cpu_backend_variant(power7_2 POWER7 VSX)
            ggml_add_cpu_backend_variant(power8_1 POWER8)
            ggml_add_cpu_backend_variant(power8_2 POWER8 VSX)
            ggml_add_cpu_backend_variant(power9  POWER9  VSX)
            ggml_add_cpu_backend_variant(power10 POWER10 VSX)
            ggml_add_cpu_backend_variant(power11 POWER11 VSX)
        else()
            message(FATAL_ERROR "Unsupported PowerPC target OS: ${CMAKE_SYSTEM_NAME}")
        endif()
    elseif (GGML_SYSTEM_ARCH STREQUAL "s390x")
        if (CMAKE_SYSTEM_NAME MATCHES "Linux")
            ggml_add_cpu_backend_variant(z15 Z15 VXE2)
            ggml_add_cpu_backend_variant(z16 Z16 VXE2 NNPA)
        else()
            message(FATAL_ERROR "Unsupported s390x target OS: ${CMAKE_SYSTEM_NAME}")
        endif()
    elseif (GGML_SYSTEM_ARCH STREQUAL "riscv64")
        if (CMAKE_SYSTEM_NAME MATCHES "Linux")
            ggml_add_cpu_backend_variant(riscv64_0)
            ggml_add_cpu_backend_variant(riscv64_v RVV)
        else()
            message(FATAL_ERROR "Unsupported RISC-V target OS: ${CMAKE_SYSTEM_NAME}")
        endif()
    else()
        message(FATAL_ERROR "GGML_CPU_ALL_VARIANTS not yet supported with ${GGML_SYSTEM_ARCH} on ${CMAKE_SYSTEM_NAME}")
    endif()
elseif (GGML_CPU)
    # Single native CPU backend (empty tag).
    ggml_add_cpu_backend_variant_impl("")
endif()
# Non-CPU backends; each is included only when its GGML_<NAME> option is ON
# (the names keep their original casing — ggml_add_backend upper/lower-cases
# them as needed).
foreach (backend BLAS CANN CUDA HIP METAL MUSA RPC SYCL
                 Vulkan WebGPU zDNN OpenCL Hexagon ZenDNN)
    ggml_add_backend(${backend})
endforeach()
# Shared configuration for both core targets.
foreach (target ggml-base ggml)
    target_include_directories(${target} PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/../include> $<INSTALL_INTERFACE:include>)
    target_compile_features   (${target} PRIVATE c_std_11 cxx_std_17) # don't bump
endforeach()

target_link_libraries(ggml-base PRIVATE Threads::Threads)

# Link libm where it exists as a separate library.
# NOTE(review): the WIN32/ONEAPI_ROOT guard presumably avoids picking up an
# incompatible "m" from a oneAPI environment on Windows — confirm.
find_library(MATH_LIBRARY m)
if (MATH_LIBRARY)
    if (NOT WIN32 OR NOT DEFINED ENV{ONEAPI_ROOT})
        target_link_libraries(ggml-base PRIVATE m)
    endif()
endif()

# dlopen() support; CMAKE_DL_LIBS is the portable spelling of the hardcoded
# "dl" it replaces (expands to the platform's dl library, or to nothing).
if (CMAKE_SYSTEM_NAME MATCHES "Android")
    target_link_libraries(ggml-base PRIVATE ${CMAKE_DL_LIBS})
endif()

if (CMAKE_SYSTEM_NAME MATCHES "visionOS")
    target_compile_definitions(ggml-base PUBLIC _DARWIN_C_SOURCE)
endif()

if (BUILD_SHARED_LIBS)
    foreach (target ggml-base ggml)
        set_target_properties(${target} PROPERTIES POSITION_INDEPENDENT_CODE ON)
        target_compile_definitions(${target} PRIVATE GGML_BUILD)
        target_compile_definitions(${target} PUBLIC  GGML_SHARED)
    endforeach()
endif()
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user