## 题目喵述

$$(\sum_{i=1}^n \sum_{j=1}^nijgcd(i,j)) \% p$$

## 题解

$$\sum_{d=1}^nd\sum_{i=1}^n\sum_{j=1}^n[gcd(i,j)==d]ij$$

$$\sum_{d=1}^n d^3\sum_{i=1}^{\frac n d}\sum_{j=1}^{\frac n d}[gcd(i,j)==1]ij$$

$$f(d)=\sum_{i=1}^n\sum_{j=1}^nij[gcd(i,j)==d]$$

$$sum(x)=\frac {(1+x) \times x} 2$$

\begin{align} g(d)&=\sum_{d|k}f(k)\\ &=\sum_{i=1}^n\sum_{j=1}^nij[d|gcd(i,j)]\\ &=\sum_{i=1}^{\frac n d}\sum_{j=1}^{\frac n d}d^2ij\\ &=d^2sum^2(\frac n d) \end{align}

$$f(1)=\sum_{d=1}^n \mu(d)g(d)$$

$$ans = \sum_{d=1}^n d^3\sum_{p=1}^{\frac n d}\mu(p)p^2sum^2(\frac n {pd})$$

令 $T=pd$：

$$\sum_{T=1}^n sum^2(\frac n T)\sum_{d|T} d^3 (\frac T d)^2 \mu(\frac T d)$$

$$\sum_{T=1}^n sum^2(\frac n T)T^2\sum_{d|T} d \mu(\frac T d)$$

$$\sum_{T=1}^n sum^2(\frac n T)T^2\phi(T)$$

$$f(n)=n^2\phi(n)$$

$$s(n)=\sum_{i=1}^n f(i)$$

$$g(n)=n^2$$

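$$(f*g)(n)=\sum_{d|n}d^2\phi(d)(\frac n d)^2=n^2\sum_{d|n}\phi(d)=n^3$$

\begin{align} s(n)=g(1)s(n)&=\sum_{i=1}^n (f*g)(i) - \sum_{d=2}^n g(d)s(\frac n d)\\ &= \sum_{i=1}^n i^3 - \sum_{d=2}^n d^2 s(\frac n d) \end{align}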

$$\sum_{i=1}^n i^3 = \frac {n^2 (n+1)^2}{4}$$

$$\sum_{i=1}^n i^2 = \frac {(2n+1)(n+1)n} 6$$

#include<bits/stdc++.h>
// Inclusive ascending loop helper: i runs over [a, b].
#define fo(i, a, b) for (int i = (a); i <= (b); ++i)

// Table size; init() fills the tables for indices up to N - 3.
#define N 10000005
#define ll long long
// vis[x]: set once x has been reached by the linear sieve in init();
// a clear bit when x is first examined means x is prime.
std::bitset<N> vis;
// phi[x]: after init(), the prefix sum  sum_{i=1}^{x} i^2 * phi(i)  (mod `mod`).
// p[1..tot]: primes found by the sieve.
// mod: the modulus, read first from input in main().
// up: largest index covered by the precomputed tables (set in init()).
ll phi[N], p[N], tot, mod, up;
// Memo for s(x) with x > up (Du's sieve).
std::map<ll, ll> mp;
inline void init ()
{
up = N - 3;
phi[1] = 1;
fo (i, 2, up)
{
if (!vis[i])
p[++tot] = i, phi[i] = i - 1;
vis[i] = 1;
for (int j = 1; j <= tot && i * p[j] <= up; ++j)
{
vis[i * p[j]] = 1;
if (!(i % p[j]))
{
phi[i * p[j]] = phi[i] * p[j];
break;
}
phi[i * p[j]] = phi[i] * phi[p[j]];
}
}
fo (i, 2, up) phi[i] = (phi[i] * i % mod * i + phi[i - 1]) % mod;
}
// n: the upper bound read from input.
// inv6: 6^(mod-2) mod `mod` (set in main), i.e. the inverse of 6 when mod is prime.
// NOTE(review): m is never used in the visible file.
ll n, m, inv6;
// Binary exponentiation: returns x^y modulo `mod`.
// With the default exponent mod - 2 this is the Fermat inverse of x
// (meaningful when mod is prime and x is not a multiple of it).
inline ll pow (ll x, ll y = mod - 2)
{
    ll result = 1;
    for (ll base = x; y != 0; y >>= 1)
    {
        if (y & 1) result = result * base % mod;
        base = base * base % mod;
    }
    return result;
}
// Square of x modulo `mod`. Callers pass values already reduced by sum(),
// so the product x * x stays within signed 64-bit range.
inline ll sqr (ll x)
{
    const ll squared = x * x;
    return squared % mod;
}
// Triangular number x(x+1)/2 modulo `mod`.
// x is reduced first. r(r+1) is always even, so the shift is an exact
// division, and for odd mod the result agrees with x(x+1)/2 mod `mod`.
inline ll sum (ll x)
{
    const ll r = x % mod;
    const ll twice = r * (r + 1);
    return (twice >> 1) % mod;
}
// Normalize x into the range [0, mod).
inline ll md (ll x)
{
    const ll r = x % mod;
    if (r < 0) return r + mod;
    return r;
}
// Despite its name, this is the prefix sum of squares
//   1^2 + 2^2 + ... + x^2 = x(x+1)(2x+1)/6   (mod `mod`),
// i.e. the prefix sum of g(d) = d^2 that s() uses for block sums.
// Relies on the global inv6 having been initialized.
inline ll cube (ll x)
{
    const ll r = x % mod;
    const ll odd = (r << 1 | 1) % mod;  // 2r + 1
    ll acc = odd * (r + 1) % mod;
    acc = acc * r % mod;
    return acc * inv6 % mod;
}
// s(x) = sum_{i=1}^{x} i^2 * phi(i)  (mod `mod`), computed with Du's sieve.
// Because sum_{i<=x} i^3 = sum_{d=1}^{x} d^2 * s(x / d), we get
//   s(x) = (x(x+1)/2)^2 - sum_{d=2}^{x} d^2 * s(x / d),
// where d is grouped into blocks on which x / d is constant.
// Values x <= up come from the precomputed table; larger ones are memoized in mp.
inline ll s (ll x)
{
    if (x <= up) return phi[x];
    // Use find() rather than operator[]: operator[] would insert a dummy 0 on
    // every miss, and would treat a memoized value that is 0 (mod p) as a miss.
    const auto memo = mp.find(x);
    if (memo != mp.end()) return memo->second;
    ll ret = sqr(sum(x));  // sum_{i<=x} i^3 = (x(x+1)/2)^2
    for (ll i = 2, j; i <= x; i = j + 1)
    {
        j = x / (x / i);
        // cube(j) - cube(i - 1) = sum_{d=i}^{j} d^2 (mod `mod`)
        ret = (ret - (cube(j) - cube(i - 1)) % mod * s(x / i)) % mod;
    }
    ret = md(ret);
    mp.emplace(x, ret);
    return ret;
}
// Reads the modulus p and the bound n, then evaluates
//   ans = sum_{T=1}^{n} sum(n/T)^2 * T^2 * phi(T)   (mod p)
// by grouping T into blocks [i, j] on which n / T is constant.
// NOTE: the formulas assume p is a prime > 3 (the inverse of 6 is taken via Fermat).
int main()
{
    // Without both values, mod would stay 0 and every `% mod` would divide by zero.
    if (scanf("%lld %lld", &mod, &n) != 2) return 1;
    init(); inv6 = pow(6);
    ll ans = 0;
    for (ll i = 1, j; i <= n; i = j + 1)
    {
        j = n / (n / i);
        // s(j) - s(i - 1) is the sum of T^2 * phi(T) over the block [i, j].
        ans = (ans + sqr(sum(n / i)) * (s(j) - s(i - 1))) % mod;
    }
    printf("%lld\n", md(ans));
    return 0;
}