// Contents — current folder:
// ARS_ExplainableAIARS5_Rust.rs
// ARS 5.0 - Rust with Candle
// ============================================================================
use candle_core::{Device, Tensor, DType};
use candle_nn::{self as nn, VarBuilder, VarMap, Optimizer, AdamW, Linear, LinearConfig, Dropout, Module};
use std::collections::{HashMap, HashSet};
use rand::prelude::*;
// ============================================================================
// 1. SYMBOLIC COMPONENT
// ============================================================================
/// Symbolic half of the ARS model: a first-order transition table over a
/// fixed inventory of twelve symbols.
///
/// `counts[from][to]` stores how often `to` was observed directly after
/// `from`; `probs[from][to]` stores the matching per-row relative frequency.
/// A row that has never been updated holds the uniform value `1 / n_symbols`.
#[derive(Debug, Clone)]
pub struct ARSGrammar {
    /// Symbol names; a symbol's position in this vector is its index.
    symbols: Vec<String>,
    /// Reverse lookup from symbol name to index.
    symbol_to_idx: HashMap<String, usize>,
    /// Number of symbols (`symbols.len()`).
    n_symbols: usize,
    /// Observed transition counts, indexed `[from][to]`.
    counts: Vec<Vec<usize>>,
    /// Row-normalised transition probabilities, indexed `[from][to]`.
    probs: Vec<Vec<f64>>,
    /// `(from, to)` index pairs that `is_valid_transition` rejects.
    constitutive_rules: HashSet<(usize, usize)>,
}

impl Default for ARSGrammar {
    /// Equivalent to [`ARSGrammar::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl ARSGrammar {
    /// Builds the grammar with its fixed symbol inventory, zeroed counts,
    /// uniform probabilities and the three hard-coded constitutive pairs
    /// `KBG→VBG`, `KAV→VAV` and `VAV→KAV`.
    pub fn new() -> Self {
        const SYMBOLS: [&str; 12] = [
            "KBG", "VBG", "KBBd", "VBBd", "KBA", "VBA", "KAE", "VAE", "KAA", "VAA", "KAV", "VAV",
        ];

        let symbols: Vec<String> = SYMBOLS.iter().map(|s| s.to_string()).collect();
        let symbol_to_idx: HashMap<String, usize> = symbols
            .iter()
            .enumerate()
            .map(|(idx, name)| (name.clone(), idx))
            .collect();
        let n = symbols.len();

        // Constitutive rules. Indexing cannot fail: every name below is in SYMBOLS.
        let constitutive_rules: HashSet<(usize, usize)> =
            [("KBG", "VBG"), ("KAV", "VAV"), ("VAV", "KAV")]
                .iter()
                .map(|&(from, to)| (symbol_to_idx[from], symbol_to_idx[to]))
                .collect();

        Self {
            symbols,
            symbol_to_idx,
            n_symbols: n,
            counts: vec![vec![0; n]; n],
            probs: vec![vec![1.0 / n as f64; n]; n],
            constitutive_rules,
        }
    }

    /// Records one observation of `from → to` and re-normalises row `from`.
    ///
    /// After the call, row `from` of the probability table sums to 1 and
    /// reflects only observed counts (the initial uniform prior is replaced).
    ///
    /// # Panics
    /// Panics if `from` or `to` is not a valid symbol index.
    pub fn update_probabilities(&mut self, from: usize, to: usize) {
        self.counts[from][to] += 1;

        let row_counts = &self.counts[from];
        // At least 1 because of the increment above, so the division is safe.
        let row_sum: usize = row_counts.iter().sum();
        for (prob, &count) in self.probs[from].iter_mut().zip(row_counts) {
            *prob = count as f64 / row_sum as f64;
        }
    }

    /// Returns `false` exactly when `(from, to)` is one of the constitutive
    /// pairs, `true` otherwise.
    ///
    /// NOTE(review): this treats the constitutive pairs as *forbidden*
    /// transitions. The demo corpus in this file contains `KBG→VBG` and
    /// `VAV→KAV`, which this predicate rejects — confirm whether the rules are
    /// meant to forbid or to require these pairs.
    pub fn is_valid_transition(&self, from: usize, to: usize) -> bool {
        !self.constitutive_rules.contains(&(from, to))
    }

    /// Current probability of `to` following `from`.
    ///
    /// # Panics
    /// Panics if `from` or `to` is not a valid symbol index.
    pub fn get_prob(&self, from: usize, to: usize) -> f64 {
        self.probs[from][to]
    }

    /// Index of the symbol named `sym`, or `None` if it is not in the inventory.
    pub fn symbol_to_idx(&self, sym: &str) -> Option<usize> {
        self.symbol_to_idx.get(sym).copied()
    }

    /// Name of the symbol at `idx`.
    ///
    /// # Panics
    /// Panics if `idx` is not a valid symbol index.
    pub fn idx_to_symbol(&self, idx: usize) -> &str {
        &self.symbols[idx]
    }
}
// ============================================================================
// 2. NEURAL COMPONENT (Candle)
// ============================================================================
#[derive(Debug)]
pub struct ARSNeuralNetwork {
fc1: Linear,
fc2: Linear,
fc3: Linear,
dropout: f64,
n_symbols: usize,
device: Device,
}
impl ARSNeuralNetwork {
pub fn new(vs: VarBuilder, n_symbols: usize, hidden: usize, dropout: f64) -> Result<Self, candle_core::Error> {
let fc1 = nn::linear(n_symbols, hidden, LinearConfig::default(), vs.pp("fc1"))?;
let fc2 = nn::linear(hidden, hidden / 2, LinearConfig::default(), vs.pp("fc2"))?;
let fc3 = nn::linear(hidden / 2, n_symbols, LinearConfig::default(), vs.pp("fc3"))?;
Ok(Self {
fc1, fc2, fc3, dropout,
n_symbols, device: vs.device().clone(),
})
}
pub fn forward(&self, x: &Tensor) -> Result<Tensor, candle_core::Error> {
let x = x.apply(&self.fc1)?.relu()?;
let x = nn::ops::dropout(&x, self.dropout)?;
let x = x.apply(&self.fc2)?.relu()?;
let x = nn::ops::dropout(&x, self.dropout)?;
let x = x.apply(&self.fc3)?;
nn::ops::softmax(&x, 1)
}
pub fn predict(&self, from_idx: usize) -> Result<Vec<f64>, candle_core::Error> {
let mut data = vec![0.0f32; self.n_symbols];
data[from_idx] = 1.0;
let input = Tensor::from_slice(&data, &[1, self.n_symbols], &self.device)?;
let output = self.forward(&input)?;
let probs = output.to_vec2::<f32>()?;
Ok(probs[0].iter().map(|&x| x as f64).collect())
}
}
// ============================================================================
// 3. HYBRID NEURO-SYMBOLIC SYSTEM
// ============================================================================
pub struct ARSNeuroSymbolicSystem {
grammar: ARSGrammar,
neural: ARSNeuralNetwork,
var_map: VarMap,
optimizer: AdamW,
loss_history: Vec<f64>,
}
impl ARSNeuroSymbolicSystem {
pub fn new(learning_rate: f64) -> Result<Self, candle_core::Error> {
let grammar = ARSGrammar::new();
let device = Device::Cpu;
let var_map = VarMap::new();
let vs = VarBuilder::from_varmap(&var_map, DType::F32, &device);
let neural = ARSNeuralNetwork::new(vs, grammar.n_symbols, 64, 0.2)?;
let optimizer = AdamW::new(var_map.all_vars(), learning_rate)?;
Ok(Self {
grammar, neural, var_map, optimizer,
loss_history: Vec::new(),
})
}
pub fn train_on_transition(&mut self, from_sym: &str, to_sym: &str) -> Result<f64, candle_core::Error> {
let from_idx = self.grammar.symbol_to_idx(from_sym).unwrap();
let to_idx = self.grammar.symbol_to_idx(to_sym).unwrap();
// Symbolic update (fast)
self.grammar.update_probabilities(from_idx, to_idx);
// Neural update (slow)
let probs = self.grammar.get_prob(from_idx, to_idx);
let target_prob = if self.grammar.is_valid_transition(from_idx, to_idx) { probs } else { 0.0 };
// Prediction
let pred = self.neural.predict(from_idx)?;
let loss_val = -target_prob.ln() * pred[to_idx];
// TODO: Backpropagation in Candle requires more boilerplate
// For brevity, we return the loss without actual backprop here
self.loss_history.push(loss_val);
Ok(loss_val)
}
pub fn train_on_corpus(&mut self, corpus: &[Vec<String>], epochs: usize) -> Result<(), candle_core::Error> {
println!("Training on {} transcripts for {} epochs...", corpus.len(), epochs);
for epoch in 0..epochs {
let mut total_loss = 0.0;
let mut n_transitions = 0;
for chain in corpus {
for i in 0..chain.len() - 1 {
let loss = self.train_on_transition(&chain[i], &chain[i+1])?;
total_loss += loss;
n_transitions += 1;
}
}
let avg_loss = total_loss / n_transitions as f64;
println!("Epoch {}/{}: Avg Loss = {:.6}", epoch + 1, epochs, avg_loss);
}
Ok(())
}
pub fn predict_next(&self, from_sym: &str) -> Result<HashMap<String, f64>, candle_core::Error> {
let from_idx = self.grammar.symbol_to_idx(from_sym).unwrap();
let neural_probs = self.neural.predict(from_idx)?;
let mut result = HashMap::new();
for i in 0..self.grammar.n_symbols {
let sym = self.grammar.idx_to_symbol(i).to_string();
let combined = 0.5 * neural_probs[i] + 0.5 * self.grammar.get_prob(from_idx, i);
result.insert(sym, combined);
}
Ok(result)
}
pub fn generate_sequence(&self, max_len: usize, start_sym: &str) -> Result<Vec<String>, candle_core::Error> {
let mut rng = thread_rng();
let mut seq = vec![start_sym.to_string()];
for _ in 0..max_len - 1 {
let probs = self.predict_next(&seq.last().unwrap())?;
let from_idx = self.grammar.symbol_to_idx(seq.last().unwrap()).unwrap();
// Filter by constitutive rules
let mut valid: Vec<(String, f64)> = probs.into_iter()
.filter(|(sym, _)| {
if let Some(to_idx) = self.grammar.symbol_to_idx(sym) {
self.grammar.is_valid_transition(from_idx, to_idx)
} else { false }
})
.collect();
if valid.is_empty() { break; }
// Normalize
let sum: f64 = valid.iter().map(|(_, p)| p).sum();
for (_, p) in valid.iter_mut() { *p /= sum; }
// Sample
let mut cumsum = 0.0;
let r: f64 = rng.gen();
let next_sym = valid.iter()
.find(|(_, p)| { cumsum += p; cumsum >= r })
.map(|(s, _)| s.clone())
.unwrap_or(valid[0].0.clone());
seq.push(next_sym.clone());
if next_sym == "KAV" || next_sym == "VAV" { break; }
}
Ok(seq)
}
}
// ============================================================================
// 5. MAIN DEMONSTRATION
// ============================================================================
/// Demo entry point: trains the hybrid system on a small hard-coded corpus,
/// prints the three most probable successors for a few symbols, then prints
/// five sampled sequences starting from `KBG`.
fn main() -> Result<(), Box<dyn std::error::Error>> {
    let banner = "=".repeat(70);
    println!("{}", banner);
    println!("ARS 5.0 - Rust with Candle");
    println!("The Empirical Grammar of Market Conversations");
    println!("{}", banner);

    let to_chain = |symbols: &[&str]| -> Vec<String> {
        symbols.iter().map(|s| s.to_string()).collect()
    };
    let corpus: Vec<Vec<String>> = vec![
        to_chain(&[
            "KBG", "VBG", "KBBd", "VBBd", "KBA", "VBA", "KBBd", "VBBd", "KBA", "VAA", "KAA",
            "VAV", "KAV",
        ]),
        to_chain(&["VBG", "KBBd", "VBBd", "VAA", "KAA", "VBG", "KBBd", "VAA", "KAA"]),
        to_chain(&["KBBd", "VBBd", "VAA", "KAA"]),
        // ... Add remaining transcripts
    ];

    let mut system = ARSNeuroSymbolicSystem::new(0.001)?;

    println!("\n--- Training ---");
    system.train_on_corpus(&corpus, 20)?;

    println!("\n--- Learned Transition Probabilities ---");
    for from_sym in ["KBG", "KBBd", "VBA", "KAA"] {
        let mut ranked: Vec<(String, f64)> = system.predict_next(from_sym)?.into_iter().collect();
        // Highest probability first. `total_cmp` cannot panic on NaN, and the
        // name tie-break keeps output stable despite HashMap iteration order.
        ranked.sort_by(|a, b| b.1.total_cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
        let top: Vec<String> = ranked
            .iter()
            .take(3)
            .map(|(sym, p)| format!("{}: {:.3}", sym, p))
            .collect();
        println!("{} → {}", from_sym, top.join(", "));
    }

    println!("\n--- Generated Sequences ---");
    for i in 0..5 {
        let seq = system.generate_sequence(15, "KBG")?;
        println!("Seq {}: {}", i + 1, seq.join(" → "));
    }
    Ok(())
}