accurating
def fit(
    data: MatchResultArrays,
    config: Config,
) -> Model:
    """Fits the model to data according to config.

    The time complexity is O(match_count * player_count * max(season) * steps).
    """
    if config.do_log:
        print(config)
    p1_win_probs = data.p1_win_prob
    p1s = data.p1
    p2s = data.p2
    seasons = data.season

    p1_win_probs = (1 - config.smoothing) * p1_win_probs + config.smoothing * 0.5
    p2_win_probs = 1.0 - p1_win_probs

    player_count = int(jnp.maximum(jnp.max(p1s), jnp.max(p2s)) + 1)
    season_count = int(jnp.max(seasons) + 1)

    (data_size,) = p1s.shape
    assert seasons.shape == (data_size,)
    assert p1s.shape == (data_size,)
    assert p2s.shape == (data_size,)
    assert p1_win_probs.shape == (data_size,)

    winner_prior = config.winner_prior_rating / config.rating_difference_for_2_to_1_odds
    loser_prior = config.loser_prior_rating / config.rating_difference_for_2_to_1_odds

    def get_ratings(p):
        return p['season_rating'] + p['shared_rating']

    def model(params):
        log_likelihood = 0.0
        ratings = get_ratings(params)
        assert ratings.shape == (player_count, season_count)
        p1_ratings = ratings[p1s, seasons]
        p2_ratings = ratings[p2s, seasons]

        assert p1_ratings.shape == (data_size,)
        assert p2_ratings.shape == (data_size,)

        # We sum instead of averaging: the more data we have, the more it should
        # outweigh the priors and even the season_rating_stability.
        mean_log_data_prob = jnp.sum(log_data_prob(p1_ratings, p2_ratings, p1_win_probs, p2_win_probs))
        log_likelihood += mean_log_data_prob

        if config.season_rating_stability > 0.0:
            log_likelihood -= config.season_rating_stability * jnp.sum((ratings[:, 1:] - ratings[:, :-1])**2)

        if config.winner_prior_match_count > 0.0:
            log_likelihood += jnp.sum(log_data_prob(ratings, jnp.ones_like(ratings) * winner_prior, 0.0, config.winner_prior_match_count))

        if config.loser_prior_match_count > 0.0:
            log_likelihood += jnp.sum(log_data_prob(ratings, jnp.ones_like(ratings) * loser_prior, config.loser_prior_match_count, 0.0))

        geomean_data_prob = jnp.exp2(mean_log_data_prob / data_size)
        return log_likelihood / data_size, geomean_data_prob

        # TODO: This is an experiment trying to evaluate ELO playing consistency.
        # Try again and delete if it does not work.
        # cons = params['consistency']
        # p1_cons = jnp.take(cons, p1s)
        # p2_cons = jnp.take(cons, p2s)
        # winner_win_prob_log = 0.0
        # winner_win_prob_log += p1_win_probs * log_win_prob_diff(diff/jnp.exp(p1_cons)) + p2_win_probs * log_win_prob_diff(-diff/jnp.exp(p1_cons))
        # winner_win_prob_log += p1_win_probs * log_win_prob_diff(diff/jnp.exp(p2_cons)) + p2_win_probs * log_win_prob_diff(-diff/jnp.exp(p2_cons))
        # winner_win_prob_log /= 2
        # return jnp.sum(winner_win_prob_log) - 0.005*jnp.sum(cons ** 2)  # or mean?

    # Optimize for these params:
    shared_rating = jnp.zeros([player_count, 1], dtype=jnp.float64) + (loser_prior + winner_prior) / 2.0
    season_rating = jnp.zeros([player_count, season_count], dtype=jnp.float64)
    params = {'season_rating': season_rating, 'shared_rating': shared_rating}
    # 'consistency': jnp.zeros([player_count, season_count]),

    # Momentum gradient descent with restarts.
    m_lr = 1.0
    lr = float(config.initial_lr)
    momentum = tree_map(jnp.zeros_like, params)
    last_params = params
    last_eval = -1e8  # Eval of the initial params is -1, but regularizations might push it lower.
    last_grad = tree_map(jnp.zeros_like, params)
    last_reset_step = 0

    for i in range(config.max_steps):
        (eval, model_fit), grad = jax.value_and_grad(model, has_aux=True)(params)

        if False:
            # Standard batch gradient descent works too; just use a good LR.
            params = tree_map(lambda p, g: p + lr * g, params, grad)
        else:
            if eval < last_eval:
                if config.do_log:
                    print(f'reset to {jnp.exp2(last_eval)}')
                lr /= 1.5
                if last_reset_step == i - 1:
                    lr /= 4
                last_reset_step = i
                momentum = tree_map(jnp.zeros_like, params)
                # momentum /= 2.
                params, eval, grad = last_params, last_eval, last_grad
            else:
                last_params, last_eval, last_grad = params, eval, grad
            momentum = tree_map(lambda m, g: m_lr * m + g, momentum, grad)
            params = tree_map(lambda p, m: p + lr * m, params, momentum)

        max_d_rating = jnp.max(
            jnp.abs(get_ratings(params) - get_ratings(last_params)))

        if config.do_log:
            g = get_ratings(grad)
            g = jnp.sqrt(jnp.mean(g * g))
            print(
                f'Step {i:4}: eval={jnp.exp2(eval):0.12f} pred_power={model_fit:0.6f} lr={lr: 4.4f} grad={g:2.8f} delta={max_d_rating}')

        if max_d_rating < 1e-15:
            break

        lr *= 1.5 ** (1.0 / 12)

    def postprocess():
        rating = {}
        for id, name in enumerate(data.player_name):
            rating[name] = {}
            for season in range(season_count):
                rating[name][season] = float(get_ratings(params)[id, season]) * config.rating_difference_for_2_to_1_odds
        model = Model(rating=rating)
        if config.do_log:
            print(model.tabulate())
        return model

    return postprocess()
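The optimizer in fit is momentum gradient ascent with a restart heuristic: whenever the objective drops, it rewinds to the last good parameters, zeroes the momentum, and shrinks the learning rate; otherwise the learning rate slowly regrows. A minimal sketch of the same heuristic on a toy concave objective (the objective and all constants here are illustrative, not from the library, and the double-reset lr /= 4 branch is omitted):

import jax
import jax.numpy as jnp

def objective(x):
    return -jnp.sum((x - 3.0) ** 2)  # Maximized at x == 3.

x = jnp.zeros(4)
momentum = jnp.zeros_like(x)
last_x, last_eval, last_grad = x, -jnp.inf, momentum
lr = 0.5

for step in range(100):
    value, grad = jax.value_and_grad(objective)(x)
    if value < last_eval:
        # Overshot: rewind, kill the momentum, shrink the step size.
        lr /= 1.5
        momentum = jnp.zeros_like(x)
        x, value, grad = last_x, last_eval, last_grad
    else:
        last_x, last_eval, last_grad = x, value, grad
    momentum = momentum + grad
    x = x + lr * momentum
    lr *= 1.5 ** (1.0 / 12)  # Gently regrow the LR between resets.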
def data_from_dicts(matches) -> MatchResultArrays:
    player_set = set()

    for match in matches:
        player_set.add(match['p1'])
        player_set.add(match['p2'])
        assert match['winner'] == match['p1'] or match['winner'] == match['p2'], match
        assert isinstance(match['season'], int)

    player_name = sorted(list(player_set))

    p1 = []
    p2 = []
    p1_win_prob = []
    season = []

    for match in matches:
        p1.append(player_name.index(match['p1']))
        p2.append(player_name.index(match['p2']))
        p1_win = match['winner'] == match['p1']
        p1_win_prob.append(1.0 if p1_win else 0.0)
        season.append(match['season'])

    return MatchResultArrays(
        p1=np.array(p1),
        p2=np.array(p2),
        p1_win_prob=np.array(p1_win_prob),
        season=np.array(season),
        player_name=player_name,
    )
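A small usage sketch (the player names and results are made up). Player ids are assigned by sorted name, so 'alice' becomes 0, 'bob' 1, 'carol' 2:

matches = [
    {'p1': 'alice', 'p2': 'bob', 'winner': 'alice', 'season': 0},
    {'p1': 'bob', 'p2': 'carol', 'winner': 'carol', 'season': 0},
    {'p1': 'alice', 'p2': 'carol', 'winner': 'alice', 'season': 1},
]
data = data_from_dicts(matches)
print(data.player_name)  # ['alice', 'bob', 'carol']
print(data.p1)           # [0 1 0]
print(data.p1_win_prob)  # [1. 0. 1.]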
@dataclasses.dataclass
class MatchResultArrays:
    """Match data for AccuRating in numpy arrays.

    All attributes have shape (match_count,).
    """

    p1: np.ndarray
    """Player 1 id (small integer)."""

    p2: np.ndarray
    """Player 2 id (small integer)."""

    p1_win_prob: np.ndarray
    """1.0 if p1 wins, 0.0 if p2 wins. Can be any number in [0.0, 1.0]."""

    season: np.ndarray
    """Currently the seasons have to be small integers."""

    player_name: list[str] | None
    """Indexed with player id. Not used in the training."""
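Direct construction is also possible, e.g. when results come pre-encoded; since p1_win_prob accepts any value in [0.0, 1.0], fractional values can encode uncertain or aggregated outcomes (the names and numbers below are illustrative):

import numpy as np

data = MatchResultArrays(
    p1=np.array([0, 1]),
    p2=np.array([1, 0]),
    p1_win_prob=np.array([1.0, 0.75]),  # 0.75: p1 "mostly" won this match.
    season=np.array([0, 0]),
    player_name=['alice', 'bob'],
)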
def win_prob(rating, opp_rating):
    """Probability of a win for the given ratings."""
    return 1.0 / (1.0 + jnp.exp2(opp_rating - rating))
    # This equivalent form is more readable:
    # return jnp.exp2(rating) / (jnp.exp2(rating) + jnp.exp2(opp_rating))
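Ratings here are in internal units, i.e. rating points divided by rating_difference_for_2_to_1_odds (fit converts back in postprocess). A one-unit difference therefore gives 2:1 odds, two units give 4:1:

win_prob(1.0, 0.0)  # == 1 / (1 + 2**-1) == 2/3, i.e. 2:1 win odds.
win_prob(2.0, 0.0)  # == 1 / (1 + 2**-2) == 4/5, i.e. 4:1 win odds.
win_prob(0.0, 0.0)  # == 0.5 for equal ratings.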
@dataclasses.dataclass
class Config:
    """AccuRating configuration."""

    season_rating_stability: float
    """Rating stability across seasons.

    Currently the seasons have to be small integers.
    season_rating_stability = 0 means that ratings in each season are completely separate.
    season_rating_stability = inf means that ratings in each season should be the same."""

    smoothing: float
    """Balance between match results and player pairings as the sources of data.

    There are two sources of data:
    - Match result: the winner probably has a higher rating than the loser.
    - Player pairing: matched players probably have similar strength.
    Setting smoothing to 0.0 ignores player pairings and relies on match results only.
    Setting smoothing to 1.0 ignores match results and relies on player pairings only.

    Typically, in the absence of data, ratings assume a prior that the skill of a player is some fixed value like 1000.
    This keeps a rating from escaping to infinity when only losses or only wins are available.
    Smoothing essentially specifies that the loser of every match still had a small chance of winning.
    This is also known as 'label smoothing'."""

    winner_prior_rating: float = 4000.0
    winner_prior_match_count: float = 0.0
    loser_prior_rating: float = 1000.0
    loser_prior_match_count: float = 0.0
    """Adds two virtual players with fixed ratings of winner_prior_rating and loser_prior_rating that always win and always lose, respectively.
    For every player and *every season*, adds winner_prior_match_count (loser_prior_match_count) games against them to the data set.
    These match counts should be much smaller than the actual number of matches the players played.
    If the match counts are set to 0.0, the prior is disabled and the resulting ratings float (can be shifted as a whole by a constant).
    """

    max_steps: int = 1_000_000
    """Limits the number of passes over the dataset."""

    do_log: bool = False
    """Enables additional logging."""

    initial_lr: float = 10000.0
    """The learning rate is adjusted automatically, but sometimes the initial value is too large and the optimization blows up."""

    rating_difference_for_2_to_1_odds: float = 100.0
    """A difference of that many rating points creates 2:1 win odds; twice the difference predicts 4:1 odds.
    You can change it to 120.412 to match the chess Elo scale.
    Apart from rescaling the final result, it also rescales the prior ratings in this config above."""
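A configuration sketch (the specific values are illustrative, not recommendations):

config = Config(
    season_rating_stability=0.5,   # Softly couple adjacent seasons.
    smoothing=0.1,                 # Every loser keeps a 5% virtual win chance (0.1 * 0.5).
    winner_prior_match_count=1.0,  # Enable the anchoring priors...
    loser_prior_match_count=1.0,   # ...so ratings cannot float freely.
    do_log=True,
)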
@dataclasses.dataclass
class Model:
    """Trained model."""

    rating: dict[str, dict[int, float]]
    """Player rating, indexed by name and season."""

    def tabulate(self):
        last_rating = []
        min_season, max_season = None, None
        for name, ratings in self.rating.items():
            assert min_season in [None, min(ratings.keys())]
            assert max_season in [None, max(ratings.keys())]
            min_season = min(ratings.keys())
            max_season = max(ratings.keys())
            last_rating.append((ratings[max_season], name))
        if min_season is None:
            return ""
        min_season += 1  # Skip season 0.
        last_rating.sort(reverse=True)
        headers = ['Nick']
        for season in range(max_season, min_season - 1, -1):
            headers.append(f'S{season}')
        table = []
        for _, name in last_rating:
            # if len(table) > 10: break  # Max rows.
            row = [name]
            for season in range(max_season, min_season - 1, -1):
                row.append(self.rating[name][season])
            table.append(row)
        return tabulate(table, headers=headers, floatfmt=".1f", numalign="decimal")
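After fitting, the per-season ratings can be read directly or pretty-printed; note that tabulate() orders players by their latest rating and skips season 0. Continuing the earlier sketches (names are illustrative):

model = fit(data, config)
print(model.rating['alice'][1])  # Alice's rating in season 1, in rating points.
print(model.tabulate())          # One row per player, one column per season.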