Untitled
unknown
python
2 years ago
4.9 kB
11
Indexable
import math
import time
# Wall-clock budget in seconds; do_baum_welch stops starting new iterations
# once 85% of it has elapsed since START_TIME.
TIME_LIMIT = 1
def flatten(A):
    """Return the items of the nested sequence A, row by row, as a new flat list."""
    return [element for row in A for element in row]
def matrix_mult(A, B):
    """Return the matrix product AB of two row-major list-of-lists matrices.

    Args:
        A: m x n matrix (list of rows).
        B: n x p matrix (list of rows).

    Returns:
        The m x p product as a new list of lists; ``[]`` when A has no rows.

    Raises:
        ValueError: if the number of columns of A differs from the number of
            rows of B. (ValueError subclasses Exception, so handlers written
            for the previous bare ``Exception`` still catch it.)
    """
    if not A:
        return []
    # The number of columns in A has to be the same as the number of rows in B,
    # in order for AB to be defined.
    if len(A[0]) != len(B):
        raise ValueError("The dimensions of A and B do not match!")
    # Materialise B's columns once so every row of A can be dotted against them.
    columns = list(zip(*B))
    return [[sum(a * b for a, b in zip(row, col)) for col in columns] for row in A]
def transform_flattened_to_matrix(A, n, k):
    """Reshape the row-major flat sequence A into an n x k list of lists.

    Only the first n * k items of A are read; a shorter A raises IndexError.
    """
    return [[A[row * k + col] for col in range(k)] for row in range(n)]
def calc_scaled_alpha(A, B, pi, O):
    """Run the forward pass, normalising the forward vector at every step.

    Args:
        A: N x N matrix; A[j][i] weights the move from state j to state i.
        B: emission matrix; B[i][O[t]] is read for state i at step t.
        pi: initial state weights of length N.
        O: observation sequence (indices into the rows of B).

    Returns:
        (alpha, c): c[t] is the sum of the unscaled forward vector at step t,
        and alpha[t] is that vector divided by c[t].
    """
    n_states = len(A)
    n_steps = len(O)
    states = range(n_states)
    alpha = [[0] * n_states for _ in range(n_steps)]
    c = [0] * n_steps

    def normalise(t, unscaled):
        # Record the step's total, then store the vector scaled by it.
        c[t] = sum(unscaled)
        alpha[t] = [value / c[t] for value in unscaled]

    normalise(0, [B[i][O[0]] * pi[i] for i in states])
    for t in range(1, n_steps):
        prev = alpha[t - 1]
        normalise(
            t,
            [B[i][O[t]] * sum(A[j][i] * prev[j] for j in states) for i in states],
        )
    return alpha, c
def calc_scaled_beta(A, B, O, c):
    """Run the backward pass, scaling each step by the forward factors c.

    The last row is 1 / c[-1] for every state; each earlier row t is
    sum_j beta[t+1][j] * B[j][O[t+1]] * A[i][j], divided by c[t].

    Returns:
        beta as a T x N list of lists (T = len(O), N = len(A)).
    """
    n_states = len(A)
    n_steps = len(O)
    beta = [[1 / c[-1]] * n_states for _ in range(n_steps)]
    for t in reversed(range(n_steps - 1)):
        beta[t] = [
            sum(beta[t + 1][j] * B[j][O[t + 1]] * A[i][j] for j in range(n_states)) / c[t]
            for i in range(n_states)
        ]
    return beta
def calc_gamma(A, B, O, alpha, beta):
    """Combine scaled alpha and beta into di-gamma and gamma tables.

    For t < T-1:
        di_gamma[t][i][j] = alpha[t][i] * A[i][j] * B[j][O[t+1]] * beta[t+1][j]
        gamma[t][i]       = sum over j of di_gamma[t][i][j]
    The di_gamma slice for t = T-1 stays all zeros, and gamma[T-1] is a copy
    of alpha[T-1].

    Returns:
        (di_gamma, gamma) as nested lists of shape T x N x N and T x N.
    """
    n_states = len(A)
    n_steps = len(O)
    states = range(n_states)
    di_gamma = [[[0] * n_states for _ in states] for _ in range(n_steps)]
    gamma = [[0] * n_states for _ in range(n_steps)]
    for t in range(n_steps - 1):
        for i in states:
            weight = alpha[t][i]
            pair_probs = [
                weight * A[i][j] * B[j][O[t + 1]] * beta[t + 1][j] for j in states
            ]
            di_gamma[t][i] = pair_probs
            gamma[t][i] = sum(pair_probs)
    last = n_steps - 1
    for i in states:
        gamma[last][i] = alpha[last][i]
    return di_gamma, gamma
def estimate_lambda(A, B, pi, O):
    """Run one Baum-Welch re-estimation step for the model (A, B, pi) on O.

    Args:
        A: N x N transition matrix; A[i][j] is used as the weight of moving
            from state i to state j.
        B: N x K emission matrix; B[i][O[t]] is read for each observation.
        pi: initial state distribution of length N.
        O: observation sequence; entries are presumably symbol indices in
            range(K) -- they are used directly as column indices into B.

    Returns:
        (A_upd, B_upd, pi_upd, c): the re-estimated model plus the per-step
        scaling factors c from the forward pass of the *input* model.

    A state with zero expected occupancy has nothing to learn from. This always
    applies to A when len(O) == 1 (there are no transitions), and also to
    unreachable states. For such a state, the current row of A (or B) is kept
    instead of dividing by zero.
    """
    N = len(A)
    T = len(O)
    K = len(B[0])
    alpha, c = calc_scaled_alpha(A, B, pi, O)
    beta = calc_scaled_beta(A, B, O, c)
    di_gamma, gamma = calc_gamma(A, B, O, alpha, beta)

    # pi is re-estimated as the state posterior at the first time step.
    pi_upd = [gamma[0][i] for i in range(N)]

    A_upd = []
    for i in range(N):
        # Expected number of transitions out of state i (steps 0 .. T-2).
        leaving = sum(gamma[t][i] for t in range(T - 1))
        if leaving == 0:
            A_upd.append(list(A[i]))
            continue
        A_upd.append([
            sum(di_gamma[t][i][j] for t in range(T - 1)) / leaving
            for j in range(N)
        ])

    B_upd = []
    for j in range(N):
        # Expected number of steps spent in state j (steps 0 .. T-1).
        occupancy = sum(gamma[t][j] for t in range(T))
        if occupancy == 0:
            B_upd.append(list(B[j]))
            continue
        # One pass over O: each step only contributes to the symbol it emitted,
        # which gives the same totals as testing every (t, k) pair.
        emitted = [0] * K
        for t, symbol in enumerate(O):
            emitted[symbol] += gamma[t][j]
        B_upd.append([count / occupancy for count in emitted])

    return A_upd, B_upd, pi_upd, c
def indicator(a, b):
    """Return 1 when a == b holds (by truthiness), otherwise 0."""
    if a == b:
        return 1
    return 0
def do_baum_welch(A, B, pi, O, precision, START_TIME):
    """Repeat Baum-Welch re-estimation until the log-likelihood stops improving.

    Args:
        A, B, pi: initial model guess (transition, emission, initial weights).
        O: observation sequence.
        precision: decimals the result is reported with. An improvement in
            log P(O | model) of at most 10**-(precision + 1) counts as converged.
        START_TIME: time.time() timestamp of the start of the run. No new
            iteration starts once 85% of TIME_LIMIT seconds have elapsed.

    Returns:
        The re-estimated (A, B, pi).
    """
    tolerance = 10 ** -(precision + 1)
    old_log_prob = -math.inf
    while time.time() - START_TIME < TIME_LIMIT * 0.85:
        A, B, pi, c = estimate_lambda(A, B, pi, O)
        # c holds the forward-pass normalisers of the model before this update.
        # Their product is P(O | model), so the log-likelihood is sum(log c[t]).
        # Use the exact values: rounding c[t] to a fixed number of decimals
        # turns factors below 5e-9 into 0.0 (ZeroDivisionError), and it discards
        # relative precision for small factors.
        log_prob = sum(math.log(c_t) for c_t in c)
        if log_prob - old_log_prob <= tolerance:
            break
        old_log_prob = log_prob
    return A, B, pi
def stringify_matrix(A, precision):
    """Serialise A as "rows cols v11 v12 ..." with every token passed through
    round(token, precision) before str() (the integer header is unchanged by this).
    """
    header = [len(A), len(A[0])]
    cells = [value for row in A for value in row]
    return " ".join(str(round(token, precision)) for token in header + cells)
def _read_matrix():
    """Read one stdin line "rows cols v1 v2 ..." and return a rows x cols matrix.

    Raises:
        ValueError: if the line does not contain exactly rows * cols values.
    """
    tokens = input().split()
    rows, cols = int(tokens[0]), int(tokens[1])
    values = [float(x) for x in tokens[2:]]
    if len(values) != rows * cols:
        raise ValueError(
            f"expected {rows * cols} values for a {rows}x{cols} matrix, got {len(values)}"
        )
    return transform_flattened_to_matrix(values, rows, cols)


def estimate_model_parameters():
    """Read an initial HMM guess and observations from stdin and re-estimate A and B.

    Expected stdin lines, in order:
        A:  "rows cols a11 a12 ..."
        B:  "rows cols b11 b12 ..."
        pi: two header tokens followed by the initial weights
        O:  one header token followed by the observation indices

    The declared dimensions are used directly. Inferring them from the token
    count (float sqrt and true division) ignored the header and could not
    detect malformed lines.

    Returns:
        (A_text, B_text): the re-estimated A and B as produced by stringify_matrix.
    """
    precision = 7
    # Started before reading input, so parsing also counts against TIME_LIMIT.
    START_TIME = time.time()
    A = _read_matrix()
    B = _read_matrix()
    pi_guess = [float(x) for x in input().split()[2:]]
    O = [int(x) for x in input().split()[1:]]
    A, B, _ = do_baum_welch(A, B, pi_guess, O, precision, START_TIME)
    return stringify_matrix(A, precision), stringify_matrix(B, precision)
if __name__ == "__main__":
    # Emit the re-estimated A and B, one serialised matrix per line.
    a_line, b_line = estimate_model_parameters()
    print(a_line, b_line, sep="\n")
Leave a Comment